diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..aab78d46e1adc78f16a8d709fdacfd31a4b706bf --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +**/.idea +*~ +*.swp +*.o +*.so +*.pyc +build +docs +dist +*.egg-info/ +**/.vscode +.clang-tidy diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 9f358a4addefcab294b83e4282bfef1f9625a249..0000000000000000000000000000000000000000 --- a/LICENSE +++ /dev/null @@ -1 +0,0 @@ -123456 diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000000000000000000000000000000000..2ba72dc45b2e613cf6b2d40afb52883c66dcf5f0 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,19 @@ +# +# MANIFEST.in +# +# Manifest template for creating the dlib source distribution. + +include MANIFEST.in +include setup.py +include README.md + +# sources +recursive-include dlib ** +recursive-include python_examples *.txt *.py +recursive-include tools/python ** + +prune tools/python/build* +prune dlib/cmake_utils/*/build* +prune dlib/test + +global-exclude *.pyc diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d1c6d7fc32d51664b20e27b0c8d51db442a7d635 --- /dev/null +++ b/README.md @@ -0,0 +1,78 @@ +# dlib C++ library [![GitHub Actions C++ Status](https://github.com/davisking/dlib/actions/workflows/build_cpp.yml/badge.svg)](https://github.com/davisking/dlib/actions/workflows/build_cpp.yml) [![GitHub Actions Python Status](https://github.com/davisking/dlib/actions/workflows/build_python.yml/badge.svg)](https://github.com/davisking/dlib/actions/workflows/build_python.yml) + +Dlib is a modern C++ toolkit containing machine learning algorithms and tools for creating complex software in C++ to solve real world problems. +See [http://dlib.net](http://dlib.net) for the main project documentation and API reference. + +## Compiling dlib C++ example programs + +Go into the examples folder and type: + +DCU编译之前,需要准备编译环境 +[environment_prepare](environment_prepare.md) + +```shell +mkdir build; cd build; +cmake .. -DCMAKE_CXX_COMPILER=/opt/dtk/bin/hipcc +cmake --build . -j16 --verbose +``` + +That will build all the examples. +If you have a CPU that supports AVX instructions then turn them on like this: + +```shell +mkdir build; cd build; cmake .. -DUSE_AVX_INSTRUCTIONS=1; cmake --build . +``` + +Doing so will make some things run faster. + +Finally, Visual Studio users should usually do everything in 64bit mode. By default Visual Studio is 32bit, both in its outputs and its own execution, so you have to explicitly +tell it to use 64bits. Since it's not the 1990s anymore you probably want to use 64bits. Do that with a cmake invocation like this: + +```shell +cmake .. -G "Visual Studio 14 2015 Win64" -T host=x64 +``` + +## Compiling your own C++ programs that use dlib + +The examples folder has a [CMake tutorial](https://github.com/davisking/dlib/blob/master/examples/CMakeLists.txt) that tells you what to do. There are also additional instructions +on the [dlib web site](http://dlib.net/compile.html). + +```shell +vcpkg install dlib +``` + +## Compiling dlib Python API + +Before you can run the Python example programs you must compile dlib. Type: + +```shell +python setup.py install --set CMAKE_CXX_COMPILER=/opt/dtk/bin/hipcc +``` + +## Running the unit test suite + +Type the following to compile and run the dlib unit test suite: + +```shell +cd dlib/test +mkdir build +cd build + +cmake .. -DCMAKE_CXX_COMPILER=/opt/dtk/bin/hipcc +cmake --build . -j16 --config Release --verbose + +./dtest --runall +./dtest -d --test_dnn +``` + +Note that on windows your compiler might put the test executable in a subfolder called `Release`. If that's the case then you have to go to that folder before running the test. + +This library is licensed under the Boost Software License, which can be found in [dlib/LICENSE.txt](https://github.com/davisking/dlib/blob/master/dlib/LICENSE.txt). The long and +short of the license is that you can use dlib however you like, even in closed source commercial software. + +## dlib sponsors + +This research is based in part upon work supported by the Office of the Director of National Intelligence (ODNI), Intelligence Advanced Research Projects Activity (IARPA) under +contract number 2014-14071600010. The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official +policies or endorsements, either expressed or implied, of ODNI, IARPA, or the U.S. Government. + diff --git a/README_HIP.md b/README_HIP.md new file mode 100644 index 0000000000000000000000000000000000000000..3bfb999c2e34ee53b4093fe87897327d4da426f7 --- /dev/null +++ b/README_HIP.md @@ -0,0 +1,67 @@ +# DLIB + +## 环境配置 + +使用DCU编译之前,需要准备编译环境。参考 +[environment prepare](environment_prepare.md) + +## 使用源码安装 + +### 编译环境准备(以dtk-23.04版本为例) + +- 拉取 apex 代码 + + ``` + git clone -b dtk-23.04 http://developer.hpccube.com/codes/aicomponent/dlib.git + ``` +- 在[开发者社区](https://developer.hpccube.com/tool/#sdk) DCU Toolkit 中下载 DTK-23.04 解压至 /opt/ 路径下,并建立软链接 + + ``` + cd /opt && ln -s dtk-23.04 dtk + ``` + +- 导入环境变量以及安装必要依赖库 + + ```shell + source /opt/dtk/env.sh + ``` + +### 编译安装 + +#### 使用cmake编译 + +```shell +mkdir build; cd build; +cmake .. -DCMAKE_CXX_COMPILER=/opt/dtk/bin/hipcc + +cmake --build . -j16 #--verbose +``` + +#### 单元测试 + +Running the unit test suite. +Type the following to compile and run the dlib unit test suite: + +- 编译单元测试 + +```shell +cd dlib/test +mkdir build +cd build + +cmake .. -DCMAKE_CXX_COMPILER=/opt/dtk/bin/hipcc +cmake --build . -j16 --config Release #--verbose +``` + +- 进行单元测试 + +```shell +./dtest --runall # 全部测试 +./dtest -d --test_dnn # 测试单个测试单元 +``` + +#### 编译 dlib Python API + +```shell +python setup.py install --set CMAKE_CXX_COMPILER=/opt/dtk/bin/hipcc +``` diff --git a/dlib/CMakeLists.txt b/dlib/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..efef50af13be6e2ced87575f38d221b1249ee477 --- /dev/null +++ b/dlib/CMakeLists.txt @@ -0,0 +1,934 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# + + +cmake_minimum_required(VERSION 3.8.0) + +set(CMAKE_DISABLE_SOURCE_CHANGES ON) +set(CMAKE_DISABLE_IN_SOURCE_BUILD ON) + +if (POLICY CMP0077) + cmake_policy(SET CMP0077 NEW) +endif () + +project(dlib) + +set(CPACK_PACKAGE_NAME "dlib") +set(CPACK_PACKAGE_VERSION_MAJOR "19") +set(CPACK_PACKAGE_VERSION_MINOR "24") +set(CPACK_PACKAGE_VERSION_PATCH "99") +set(VERSION ${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH}) +# Only print these messages once, even if dlib is added multiple times via add_subdirectory() +if (NOT TARGET dlib) + message(STATUS "Using CMake version: ${CMAKE_VERSION}") + message(STATUS "Compiling dlib version: ${VERSION}") +endif () + + +include(cmake_utils/set_compiler_specific_options.cmake) +## set dtk +#set(CMAKE_CXX_COMPILER /opt/dtk/bin/hipcc) +#message(STATUS "dlib CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER_ID}") +#message(STATUS "dlib CMAKE_C_COMPILER: ${CMAKE_C_COMPILER_ID}") + +# Adhere to GNU filesystem layout conventions +message("GNUInstallDirs: " ${GNUInstallDirs}) +include(GNUInstallDirs) + +if (POLICY CMP0075) + cmake_policy(SET CMP0075 NEW) +endif () + +# default to a Release build (except if CMAKE_BUILD_TYPE is set) +include(cmake_utils/release_build_by_default) + +# Set DLIB_VERSION in the including CMake file so they can use it to do whatever they want. +get_directory_property(has_parent PARENT_DIRECTORY) +if (has_parent) + set(DLIB_VERSION ${VERSION} PARENT_SCOPE) + if (NOT DEFINED DLIB_IN_PROJECT_BUILD) + set(DLIB_IN_PROJECT_BUILD true) + endif () +endif () + + +if (COMMAND pybind11_add_module AND MSVC) + # True when building a python extension module using Visual Studio. We care + # about this because a huge number of windows users have broken systems, and + # in particular, they have broken or incompatibly installed copies of things + # like libjpeg or libpng. So if we detect we are in this mode we will never + # ever link to those libraries. Instead, we link to the copy included with + # dlib. + set(BUILDING_PYTHON_IN_MSVC true) +else () + set(BUILDING_PYTHON_IN_MSVC false) +endif () + +if (DLIB_IN_PROJECT_BUILD) + + # Check if we are being built as part of a pybind11 module. + if (COMMAND pybind11_add_module) + set(CMAKE_POSITION_INDEPENDENT_CODE True) + if (CMAKE_COMPILER_IS_GNUCXX) + # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set + # -fPIC for GCC but sometimes it still doesn't get set, so make sure it + # does. + add_definitions("-fPIC") + endif () + # Make DLIB_ASSERT statements not abort the python interpreter, but just return an error. + list(APPEND active_preprocessor_switches "-DDLIB_NO_ABORT_ON_2ND_FATAL_ERROR") + endif () + + # DLIB_IN_PROJECT_BUILD==true means you are using dlib by invoking + # add_subdirectory(dlib) in the parent project. In this case, we always want + # to build dlib as a static library so the parent project doesn't need to + # deal with some random dlib shared library file. It is much better to + # statically compile dlib into the parent project. So the following bit of + # CMake ensures that happens. However, we have to take care to compile dlib + # with position independent code if appropriate (i.e. if the parent project + # is a shared library). + if (BUILD_SHARED_LIBS) + if (CMAKE_COMPILER_IS_GNUCXX) + # Just setting CMAKE_POSITION_INDEPENDENT_CODE should be enough to set + # -fPIC for GCC but sometimes it still doesn't get set, so make sure it + # does. + add_definitions("-fPIC") + endif () + set(CMAKE_POSITION_INDEPENDENT_CODE true) + endif () + + # Tell cmake to build dlib as a static library + set(BUILD_SHARED_LIBS false) + +elseif (BUILD_SHARED_LIBS) + if (MSVC) + message(FATAL_ERROR "Building dlib as a standalone dll is not supported when using Visual Studio. You are highly encouraged to use static linking instead. See https://github.com/davisking/dlib/issues/1483 for a discussion.") + endif () +endif () + + +if (CMAKE_VERSION VERSION_LESS "3.9.0") + # Set only because there are old target_link_libraries() statements in the + # FindCUDA.cmake file that comes with CMake that error out if the new behavior + # is used. In newer versions of CMake we can instead set ROCM_LINK_LIBRARIES_KEYWORD which fixes this issue. + cmake_policy(SET CMP0023 OLD) +else () + set(ROCM_LINK_LIBRARIES_KEYWORD PUBLIC) +endif () + + +macro(enable_preprocessor_switch option_name) + list(APPEND active_preprocessor_switches "-D${option_name}") +endmacro() + +macro(disable_preprocessor_switch option_name) + if (active_preprocessor_switches) + list(REMOVE_ITEM active_preprocessor_switches "-D${option_name}") + endif () +endmacro() + +macro(toggle_preprocessor_switch option_name) + if (${option_name}) + enable_preprocessor_switch(${option_name}) + else () + disable_preprocessor_switch(${option_name}) + endif () +endmacro() + + +# Suppress superfluous randlib warnings about libdlib.a having no symbols on MacOSX. +if (APPLE) + set(CMAKE_C_ARCHIVE_CREATE " Scr ") + set(CMAKE_CXX_ARCHIVE_CREATE " Scr ") + set(CMAKE_C_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") + set(CMAKE_CXX_ARCHIVE_FINISH " -no_warning_for_no_symbols -c ") +endif () + +# Don't try to call add_library(dlib) and setup dlib's stuff if it has already +# been done by some other part of the current cmake project. We do this +# because it avoids getting warnings/errors about cmake policy CMP0002. This +# happens when a project tries to call add_subdirectory() on dlib more than +# once. This most often happens when the top level of a project depends on two +# or more other things which both depend on dlib. +if (NOT TARGET dlib) + + set(DLIB_ISO_CPP_ONLY_STR + "Enable this if you don't want to compile any non-ISO C++ code (i.e. you don't use any of the API Wrappers)") + set(DLIB_NO_GUI_SUPPORT_STR + "Enable this if you don't want to compile any of the dlib GUI code") + set(DLIB_ENABLE_STACK_TRACE_STR + "Enable this if you want to turn on the DLIB_STACK_TRACE macros") + set(DLIB_USE_BLAS_STR + "Disable this if you don't want to use a BLAS library") + set(DLIB_USE_LAPACK_STR + "Disable this if you don't want to use a LAPACK library") + set(DLIB_USE_ROCM_STR + "Disable this if you don't want to use ROCM") + set(DLIB_USE_ROCM_COMPUTE_CAPABILITIES_STR + "Set this to a comma-separated list of ROCM compute capabilities") + set(DLIB_USE_MKL_SEQUENTIAL_STR + "Enable this if you have MKL installed and want to use the sequential version instead of the multi-core version.") + set(DLIB_USE_MKL_WITH_TBB_STR + "Enable this if you have MKL installed and want to use the tbb version instead of the openmp version.") + set(DLIB_PNG_SUPPORT_STR + "Disable this if you don't want to link against libpng") + set(DLIB_GIF_SUPPORT_STR + "Disable this if you don't want to link against libgif") + set(DLIB_JPEG_SUPPORT_STR + "Disable this if you don't want to link against libjpeg") + set(DLIB_WEBP_SUPPORT_STR + "Disable this if you don't want to link against libwebp") + set(DLIB_LINK_WITH_SQLITE3_STR + "Disable this if you don't want to link against sqlite3") + #set (DLIB_USE_FFTW_STR "Disable this if you don't want to link against fftw" ) + set(DLIB_USE_MKL_FFT_STR + "Disable this is you don't want to use the MKL DFTI FFT implementation") + set(DLIB_ENABLE_ASSERTS_STR + "Enable this if you want to turn on the DLIB_ASSERT macro") + set(DLIB_USE_FFMPEG_STR + "Disable this if you don't want to use the FFMPEG library") + + option(DLIB_ENABLE_ASSERTS ${DLIB_ENABLE_ASSERTS_STR} OFF) + option(DLIB_ISO_CPP_ONLY ${DLIB_ISO_CPP_ONLY_STR} OFF) + toggle_preprocessor_switch(DLIB_ISO_CPP_ONLY) + option(DLIB_NO_GUI_SUPPORT ${DLIB_NO_GUI_SUPPORT_STR} OFF) + toggle_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + option(DLIB_ENABLE_STACK_TRACE ${DLIB_ENABLE_STACK_TRACE_STR} OFF) + toggle_preprocessor_switch(DLIB_ENABLE_STACK_TRACE) + option(DLIB_USE_MKL_SEQUENTIAL ${DLIB_USE_MKL_SEQUENTIAL_STR} OFF) + option(DLIB_USE_MKL_WITH_TBB ${DLIB_USE_MKL_WITH_TBB_STR} OFF) + + if (DLIB_ENABLE_ASSERTS) + # Set these variables so they are set in the config.h.in file when dlib + # is installed. + set(DLIB_DISABLE_ASSERTS false) + set(ENABLE_ASSERTS true) + enable_preprocessor_switch(ENABLE_ASSERTS) + disable_preprocessor_switch(DLIB_DISABLE_ASSERTS) + else () + # Set these variables so they are set in the config.h.in file when dlib + # is installed. + set(DLIB_DISABLE_ASSERTS true) + set(ENABLE_ASSERTS false) + disable_preprocessor_switch(ENABLE_ASSERTS) + # Never force the asserts off when doing an in project build. The only + # time this matters is when using visual studio. The visual studio IDE + # has a drop down that lets the user select either release or debug + # builds. The DLIB_ASSERT macro is setup to enable/disable automatically + # based on this drop down (via preprocessor magic). However, if + # DLIB_DISABLE_ASSERTS is defined it permanently disables asserts no + # matter what, which would defeat the visual studio drop down. So here + # we make a point to not do that kind of severe disabling when in a + # project build. It should also be pointed out that DLIB_DISABLE_ASSERTS + # is only needed when building and installing dlib as a separately + # installed library. It doesn't matter when doing an in project build. + if (NOT DLIB_IN_PROJECT_BUILD) + enable_preprocessor_switch(DLIB_DISABLE_ASSERTS) + endif () + endif () + + if (DLIB_ISO_CPP_ONLY) + option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} OFF) + option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} OFF) + option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} OFF) + option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} OFF) + option(DLIB_USE_ROCM ${DLIB_USE_ROCM_STR} OFF) + option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} OFF) + option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} OFF) + option(DLIB_WEBP_SUPPORT ${DLIB_WEBP_SUPPORT_STR} OFF) + #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} OFF) + option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} OFF) + option(DLIB_USE_FFMPEG ${DLIB_USE_FFMPEG_STR} OFF) + else () + option(DLIB_JPEG_SUPPORT ${DLIB_JPEG_SUPPORT_STR} ON) + option(DLIB_WEBP_SUPPORT ${DLIB_WEBP_SUPPORT_STR} ON) + option(DLIB_LINK_WITH_SQLITE3 ${DLIB_LINK_WITH_SQLITE3_STR} ON) + option(DLIB_USE_BLAS ${DLIB_USE_BLAS_STR} ON) + option(DLIB_USE_LAPACK ${DLIB_USE_LAPACK_STR} ON) + option(DLIB_USE_ROCM ${DLIB_USE_ROCM_STR} ON) + option(DLIB_PNG_SUPPORT ${DLIB_PNG_SUPPORT_STR} ON) + option(DLIB_GIF_SUPPORT ${DLIB_GIF_SUPPORT_STR} ON) + #option(DLIB_USE_FFTW ${DLIB_USE_FFTW_STR} ON) + option(DLIB_USE_MKL_FFT ${DLIB_USE_MKL_FFT_STR} ON) + option(DLIB_USE_FFMPEG ${DLIB_USE_FFMPEG_STR} ON) + endif () + + toggle_preprocessor_switch(DLIB_JPEG_SUPPORT) + toggle_preprocessor_switch(DLIB_WEBP_SUPPORT) + toggle_preprocessor_switch(DLIB_USE_BLAS) + toggle_preprocessor_switch(DLIB_USE_LAPACK) + toggle_preprocessor_switch(DLIB_USE_ROCM) + toggle_preprocessor_switch(DLIB_PNG_SUPPORT) + toggle_preprocessor_switch(DLIB_GIF_SUPPORT) + #toggle_preprocessor_switch(DLIB_USE_FFTW) + toggle_preprocessor_switch(DLIB_USE_MKL_FFT) + toggle_preprocessor_switch(DLIB_USE_FFMPEG) + + + set(source_files + base64/base64_kernel_1.cpp + bigint/bigint_kernel_1.cpp + bigint/bigint_kernel_2.cpp + bit_stream/bit_stream_kernel_1.cpp + entropy_decoder/entropy_decoder_kernel_1.cpp + entropy_decoder/entropy_decoder_kernel_2.cpp + entropy_encoder/entropy_encoder_kernel_1.cpp + entropy_encoder/entropy_encoder_kernel_2.cpp + md5/md5_kernel_1.cpp + tokenizer/tokenizer_kernel_1.cpp + unicode/unicode.cpp + test_for_odr_violations.cpp + ) + + set(dlib_needed_public_libraries) + set(dlib_needed_public_includes) + set(dlib_needed_public_cflags) + set(dlib_needed_public_ldflags) + set(dlib_needed_private_libraries) + set(dlib_needed_private_includes) + + message(STATUS "DLIB_ISO_CPP_ONLY: " ${DLIB_ISO_CPP_ONLY}) + if (DLIB_ISO_CPP_ONLY) + add_library(dlib ${source_files}) + else () + + set(source_files ${source_files} + sockets/sockets_kernel_1.cpp + bsp/bsp.cpp + dir_nav/dir_nav_kernel_1.cpp + dir_nav/dir_nav_kernel_2.cpp + dir_nav/dir_nav_extensions.cpp + gui_widgets/fonts.cpp + linker/linker_kernel_1.cpp + logger/extra_logger_headers.cpp + logger/logger_kernel_1.cpp + logger/logger_config_file.cpp + misc_api/misc_api_kernel_1.cpp + misc_api/misc_api_kernel_2.cpp + sockets/sockets_extensions.cpp + sockets/sockets_kernel_2.cpp + sockstreambuf/sockstreambuf.cpp + sockstreambuf/sockstreambuf_unbuffered.cpp + server/server_kernel.cpp + server/server_iostream.cpp + server/server_http.cpp + threads/multithreaded_object_extension.cpp + threads/threaded_object_extension.cpp + threads/threads_kernel_1.cpp + threads/threads_kernel_2.cpp + threads/threads_kernel_shared.cpp + threads/thread_pool_extension.cpp + threads/async.cpp + timer/timer.cpp + stack_trace.cpp + rocm/cpu_dlib.cpp + rocm/tensor_tools.cpp + data_io/image_dataset_metadata.cpp + data_io/mnist.cpp + data_io/cifar.cpp + global_optimization/global_function_search.cpp + filtering/kalman_filter.cpp + svm/auto.cpp + ) + + if (UNIX) + set(CMAKE_THREAD_PREFER_PTHREAD ON) + find_package(Threads REQUIRED) + list(APPEND dlib_needed_private_libraries ${CMAKE_THREAD_LIBS_INIT}) + endif () + + # we want to link to the right stuff depending on our platform. + if (WIN32 AND NOT CYGWIN) ############################################################################### + if (DLIB_NO_GUI_SUPPORT) + list(APPEND dlib_needed_private_libraries ws2_32 winmm) + else () + list(APPEND dlib_needed_private_libraries ws2_32 winmm comctl32 gdi32 imm32) + endif () + elseif (APPLE) ############################################################################ + set(CMAKE_MACOSX_RPATH 1) + if (NOT DLIB_NO_GUI_SUPPORT) + find_package(X11 QUIET) + if (X11_FOUND) + # If both X11 and anaconda are installed, it's possible for the + # anaconda path to appear before /opt/X11, so we remove anaconda. + foreach (ITR ${X11_INCLUDE_DIR}) + if ("${ITR}" MATCHES "(.*)(Ana|ana|mini)conda(.*)") + list(REMOVE_ITEM X11_INCLUDE_DIR ${ITR}) + endif () + endforeach (ITR) + list(APPEND dlib_needed_public_includes ${X11_INCLUDE_DIR}) + list(APPEND dlib_needed_public_libraries ${X11_LIBRARIES}) + else () + find_library(xlib X11) + # Make sure X11 is in the include path. Note that we look for + # Xlocale.h rather than Xlib.h because it avoids finding a partial + # copy of the X11 headers on systems with anaconda installed. + find_path(xlib_path Xlocale.h + PATHS + /Developer/SDKs/MacOSX10.4u.sdk/usr/X11R6/include + /opt/local/include + PATH_SUFFIXES X11 + ) + if (xlib AND xlib_path) + get_filename_component(x11_path ${xlib_path} PATH CACHE) + list(APPEND dlib_needed_public_includes ${x11_path}) + list(APPEND dlib_needed_public_libraries ${xlib}) + set(X11_FOUND 1) + endif () + endif () + if (NOT X11_FOUND) + message(" *****************************************************************************") + message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") + message(" *** Make sure XQuartz is installed if you want GUI support. ***") + message(" *** You can download XQuartz from: https://www.xquartz.org/ ***") + message(" *****************************************************************************") + set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE) + enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + endif () + endif () + + mark_as_advanced(xlib xlib_path x11_path) + else () ################################################################################## + # link to the socket library if it exists. this is something you need on solaris + find_library(socketlib socket) + if (socketlib) + list(APPEND dlib_needed_private_libraries ${socketlib}) + endif () + + if (NOT DLIB_NO_GUI_SUPPORT) + include(FindX11) + if (X11_FOUND) + list(APPEND dlib_needed_private_includes ${X11_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${X11_LIBRARIES}) + else () + message(" *****************************************************************************") + message(" *** DLIB GUI SUPPORT DISABLED BECAUSE X11 DEVELOPMENT LIBRARIES NOT FOUND ***") + message(" *** Make sure libx11-dev is installed if you want GUI support. ***") + message(" *** On Ubuntu run: sudo apt-get install libx11-dev ***") + message(" *****************************************************************************") + set(DLIB_NO_GUI_SUPPORT ON CACHE STRING ${DLIB_NO_GUI_SUPPORT_STR} FORCE) + enable_preprocessor_switch(DLIB_NO_GUI_SUPPORT) + endif () + endif () + + mark_as_advanced(nsllib socketlib) + endif () ################################################################################## + + if (NOT DLIB_NO_GUI_SUPPORT) + set(source_files ${source_files} + gui_widgets/widgets.cpp + gui_widgets/drawable.cpp + gui_widgets/canvas_drawing.cpp + gui_widgets/style.cpp + gui_widgets/base_widgets.cpp + gui_core/gui_core_kernel_1.cpp + gui_core/gui_core_kernel_2.cpp + ) + endif () + + INCLUDE(CheckFunctionExists) + + if (DLIB_GIF_SUPPORT) + find_package(GIF QUIET) + if (GIF_FOUND) + list(APPEND dlib_needed_public_includes ${GIF_INCLUDE_DIR}) + list(APPEND dlib_needed_public_libraries ${GIF_LIBRARY}) + else () + set(DLIB_GIF_SUPPORT OFF CACHE STRING ${DLIB_GIF_SUPPORT_STR} FORCE) + toggle_preprocessor_switch(DLIB_GIF_SUPPORT) + endif () + endif () + + if (DLIB_PNG_SUPPORT) + include(cmake_utils/find_libpng.cmake) + if (PNG_FOUND) + list(APPEND dlib_needed_private_includes ${PNG_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${PNG_LIBRARIES}) + else () + # If we can't find libpng then statically compile it in. + include_directories(external/libpng external/zlib) + set(source_files ${source_files} + external/libpng/arm/arm_init.c + external/libpng/arm/filter_neon_intrinsics.c + external/libpng/arm/palette_neon_intrinsics.c + external/libpng/png.c + external/libpng/pngerror.c + external/libpng/pngget.c + external/libpng/pngmem.c + external/libpng/pngpread.c + external/libpng/pngread.c + external/libpng/pngrio.c + external/libpng/pngrtran.c + external/libpng/pngrutil.c + external/libpng/pngset.c + external/libpng/pngtrans.c + external/libpng/pngwio.c + external/libpng/pngwrite.c + external/libpng/pngwtran.c + external/libpng/pngwutil.c + external/zlib/adler32.c + external/zlib/compress.c + external/zlib/crc32.c + external/zlib/deflate.c + external/zlib/gzclose.c + external/zlib/gzlib.c + external/zlib/gzread.c + external/zlib/gzwrite.c + external/zlib/infback.c + external/zlib/inffast.c + external/zlib/inflate.c + external/zlib/inftrees.c + external/zlib/trees.c + external/zlib/uncompr.c + external/zlib/zutil.c + ) + + include(cmake_utils/check_if_neon_available.cmake) + if (ARM_NEON_IS_AVAILABLE) + message(STATUS "NEON instructions will be used for libpng.") + enable_language(ASM) + set(source_files ${source_files} + external/libpng/arm/arm_init.c + external/libpng/arm/filter_neon_intrinsics.c + external/libpng/arm/filter_neon.S + ) + set_source_files_properties(external/libpng/arm/filter_neon.S PROPERTIES COMPILE_FLAGS "${CMAKE_ASM_FLAGS} ${CMAKE_CXX_FLAGS} -x assembler-with-cpp") + endif () + endif () + set(source_files ${source_files} + image_loader/png_loader.cpp + image_saver/save_png.cpp + ) + endif () + + if (DLIB_JPEG_SUPPORT) + include(cmake_utils/find_libjpeg.cmake) + if (JPEG_FOUND) + list(APPEND dlib_needed_private_includes ${JPEG_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${JPEG_LIBRARY}) + else () + # If we can't find libjpeg then statically compile it in. + add_definitions(-DDLIB_JPEG_STATIC) + set(source_files ${source_files} + external/libjpeg/jaricom.c + external/libjpeg/jcapimin.c + external/libjpeg/jcapistd.c + external/libjpeg/jcarith.c + external/libjpeg/jccoefct.c + external/libjpeg/jccolor.c + external/libjpeg/jcdctmgr.c + external/libjpeg/jchuff.c + external/libjpeg/jcinit.c + external/libjpeg/jcmainct.c + external/libjpeg/jcmarker.c + external/libjpeg/jcmaster.c + external/libjpeg/jcomapi.c + external/libjpeg/jcparam.c + external/libjpeg/jcprepct.c + external/libjpeg/jcsample.c + external/libjpeg/jdapimin.c + external/libjpeg/jdapistd.c + external/libjpeg/jdarith.c + external/libjpeg/jdatadst.c + external/libjpeg/jdatasrc.c + external/libjpeg/jdcoefct.c + external/libjpeg/jdcolor.c + external/libjpeg/jddctmgr.c + external/libjpeg/jdhuff.c + external/libjpeg/jdinput.c + external/libjpeg/jdmainct.c + external/libjpeg/jdmarker.c + external/libjpeg/jdmaster.c + external/libjpeg/jdmerge.c + external/libjpeg/jdpostct.c + external/libjpeg/jdsample.c + external/libjpeg/jerror.c + external/libjpeg/jfdctflt.c + external/libjpeg/jfdctfst.c + external/libjpeg/jfdctint.c + external/libjpeg/jidctflt.c + external/libjpeg/jidctfst.c + external/libjpeg/jidctint.c + external/libjpeg/jmemmgr.c + external/libjpeg/jmemnobs.c + external/libjpeg/jquant1.c + external/libjpeg/jquant2.c + external/libjpeg/jutils.c + ) + endif () + set(source_files ${source_files} + image_loader/jpeg_loader.cpp + image_saver/save_jpeg.cpp + ) + endif () + if (DLIB_WEBP_SUPPORT) + include(cmake_utils/find_libwebp.cmake) + if (WEBP_FOUND) + list(APPEND dlib_needed_private_includes ${WEBP_INCLUDE_DIR}) + list(APPEND dlib_needed_private_libraries ${WEBP_LIBRARY}) + set(source_files ${source_files} + image_loader/webp_loader.cpp + image_saver/save_webp.cpp + ) + else () + set(DLIB_WEBP_SUPPORT OFF CACHE BOOL ${DLIB_WEBP_SUPPORT_STR} FORCE) + toggle_preprocessor_switch(DLIB_WEBP_SUPPORT) + endif () + endif () + + + if (DLIB_USE_BLAS OR DLIB_USE_LAPACK OR DLIB_USE_MKL_FFT) + if (DLIB_USE_MKL_WITH_TBB AND DLIB_USE_MKL_SEQUENTIAL) + set(DLIB_USE_MKL_SEQUENTIAL OFF CACHE STRING ${DLIB_USE_MKL_SEQUENTIAL_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_MKL_SEQUENTIAL) + message(STATUS "Disabling DLIB_USE_MKL_SEQUENTIAL. It cannot be used simultaneously with DLIB_USE_MKL_WITH_TBB.") + endif () + + + # Try to find BLAS, LAPACK and MKL + include(cmake_utils/find_blas.cmake) + + if (DLIB_USE_BLAS) + if (blas_found) + list(APPEND dlib_needed_private_libraries ${blas_libraries}) + else () + set(DLIB_USE_BLAS OFF CACHE STRING ${DLIB_USE_BLAS_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_BLAS) + endif () + endif () + + if (DLIB_USE_LAPACK) + if (lapack_found) + list(APPEND dlib_needed_private_libraries ${lapack_libraries}) + if (lapack_with_underscore) + set(LAPACK_FORCE_UNDERSCORE 1) + enable_preprocessor_switch(LAPACK_FORCE_UNDERSCORE) + elseif (lapack_without_underscore) + set(LAPACK_FORCE_NOUNDERSCORE 1) + enable_preprocessor_switch(LAPACK_FORCE_NOUNDERSCORE) + endif () + else () + set(DLIB_USE_LAPACK OFF CACHE STRING ${DLIB_USE_LAPACK_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_LAPACK) + endif () + endif () + + if (DLIB_USE_MKL_FFT) + if (found_intel_mkl AND found_intel_mkl_headers) + list(APPEND dlib_needed_public_includes ${mkl_include_dir}) + list(APPEND dlib_needed_public_libraries ${mkl_libraries}) + else () + set(DLIB_USE_MKL_FFT OFF CACHE STRING ${DLIB_USE_MKL_FFT_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_MKL_FFT) + endif () + endif () + endif () + + + message("DLIB_USE_ROCM " ${DLIB_USE_ROCM}) + if (DLIB_USE_ROCM) + # There is some bug in cmake that causes it to mess up the + # -std=c++11 option if you let it propagate it to nvcc in some + # cases. So instead we disable this and manually include + # things from CMAKE_CXX_FLAGS in the ROCM_HIPCC_FLAGS list below. + if (APPLE) + set(ROCM_PROPAGATE_HOST_FLAGS OFF) + # Grab all the -D flags from CMAKE_CXX_FLAGS so we can pass them + # to nvcc. + string(REGEX MATCHALL "-D[^ ]*" FLAGS_FOR_NVCC "${CMAKE_CXX_FLAGS}") + + # Check if we are being built as part of a pybind11 module. + if (COMMAND pybind11_add_module) + # Don't export unnecessary symbols. + list(APPEND FLAGS_FOR_NVCC "-Xcompiler=-fvisibility=hidden") + endif () + endif () + + # Note that we add __STRICT_ANSI__ to avoid freaking out nvcc with gcc specific + # magic in the standard C++ header files (since nvcc uses gcc headers on linux). + list(APPEND ROCM_HIPCC_FLAGS "-D__STRICT_ANSI__;-D_MWAITXINTRIN_H_INCLUDED;-D_FORCE_INLINES;${FLAGS_FOR_NVCC}") + list(APPEND ROCM_HIPCC_FLAGS ${active_preprocessor_switches}) + + if (NOT DLIB_IN_PROJECT_BUILD) + LIST(APPEND ROCM_HIPCC_FLAGS -DDLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) + endif () + + # if (NOT MSVC) + # # list(APPEND ROCM_HIPCC_FLAGS "--gpu-max-threads-per-block=1024") + # list(APPEND ROCM_HIPCC_FLAGS "-std=c++14") + # endif () + + if (CMAKE_POSITION_INDEPENDENT_CODE) + # sometimes this setting isn't propagated to NVCC, which then causes the + # compile to fail. So make sure it's propagated. + if (NOT MSVC) # Visual studio doesn't have -fPIC so don't do it in that case. + # list(APPEND ROCM_HIPCC_FLAGS "-Xcompiler -fPIC") + list(APPEND ROCM_HIPCC_FLAGS "-fPIC") + endif () + endif () + + message("ROCM_HIPCC_FLAGS: " ${ROCM_HIPCC_FLAGS}) + list(APPEND dlib_needed_public_cflags ${ROCM_HIPCC_FLAGS} -I/usr/include -I/opt/dtk/include -I/opt/dtk/include/rocrand) + # list(APPEND dlib_needed_public_cflags ${ROCM_HIPCC_FLAGS}) + + message("DLIB_USE_ROCM " ${DLIB_USE_ROCM}) + set(source_files ${source_files} + rocm/rocm_dlib.cpp + rocm/hipdnn_dlibapi.cpp + rocm/hipdnn_miopen.cpp + rocm/logger.cpp + rocm/rocblas_dlibapi.cpp + rocm/hipsolver_dlibapi.cpp + rocm/hiprand_dlibapi.cpp + rocm/rocm_data_ptr.cpp + rocm/gpu_data.cpp + ) + list(APPEND dlib_needed_private_libraries rocblas) + list(APPEND dlib_needed_private_libraries MIOpen) + list(APPEND dlib_needed_private_libraries hiprand) + list(APPEND dlib_needed_private_libraries hipsolver) + # if (openmp_libraries) + # list(APPEND dlib_needed_private_libraries ${openmp_libraries}) + # endif () + list(APPEND dlib_needed_public_includes /opt/dtk/include) + message("333 add dtk lib and includes done") + message(STATUS "Enabling ROCM support for dlib. DLIB WILL USE ROCM") + + + message("DLIB_LINK_WITH_SQLITE3 " ${DLIB_LINK_WITH_SQLITE3}) + if (DLIB_LINK_WITH_SQLITE3) + find_library(sqlite sqlite3) + # make sure sqlite3.h is in the include path + find_path(sqlite_path sqlite3.h) + if (sqlite AND sqlite_path) + list(APPEND dlib_needed_public_includes ${sqlite_path}) + list(APPEND dlib_needed_public_libraries ${sqlite}) + else () + set(DLIB_LINK_WITH_SQLITE3 OFF CACHE STRING ${DLIB_LINK_WITH_SQLITE3_STR} FORCE) + endif () + mark_as_advanced(sqlite sqlite_path) + endif () + + message("DLIB_USE_FFTW " ${DLIB_USE_FFTW}) + if (DLIB_USE_FFTW) + find_library(fftw fftw3) + # make sure fftw3.h is in the include path + find_path(fftw_path fftw3.h) + if (fftw AND fftw_path) + list(APPEND dlib_needed_private_includes ${fftw_path}) + list(APPEND dlib_needed_private_libraries ${fftw}) + else () + set(DLIB_USE_FFTW OFF CACHE STRING ${DLIB_USE_FFTW_STR} FORCE) + toggle_preprocessor_switch(DLIB_USE_FFTW) + endif () + mark_as_advanced(fftw fftw_path) + endif () + + message("DLIB_USE_FFMPEG " ${DLIB_USE_FFMPEG}) + if (DLIB_USE_FFMPEG) + include(cmake_utils/find_ffmpeg.cmake) + message(STATUS "FFMPEG_INCLUDE_DIRS: " ${FFMPEG_INCLUDE_DIRS}) + message(STATUS "FFMPEG_LINK_LIBRARIES: " ${FFMPEG_LINK_LIBRARIES}) + if (FFMPEG_FOUND) + list(APPEND dlib_needed_public_includes ${FFMPEG_INCLUDE_DIRS}) + list(APPEND dlib_needed_public_libraries ${FFMPEG_LINK_LIBRARIES}) + list(APPEND dlib_needed_public_cflags ${FFMPEG_CFLAGS}) + list(APPEND dlib_needed_public_ldflags ${FFMPEG_LDFLAGS}) + enable_preprocessor_switch(DLIB_USE_FFMPEG) + else () + set(DLIB_USE_FFMPEG OFF CACHE BOOL ${DLIB_USE_FFMPEG_STR} FORCE) + disable_preprocessor_switch(DLIB_USE_FFMPEG) + endif () + endif () + + # Tell CMake to build dlib via add_library() + message("source_files: ${source_files}") + add_library(dlib ${source_files}) + endif () + endif () ##### end of if NOT DLIB_ISO_CPP_ONLY ########################################################## + + set(cmake_flags_values ${dlib_needed_public_cflags} --gpu-max-threads-per-block=1024 -std=c++14) + message("cmake_flags_values:" ${cmake_flags_values}) + string(REPLACE ";" " " pkg_cmake_flags_values "${cmake_flags_values}") + message("pkg_cmake_flags_values:" ${pkg_cmake_flags_values}) + set(CMAKE_CXX_FLAGS ${pkg_cmake_flags_values}) + + message(STATUS "dlib_needed_public_includes: " ${dlib_needed_public_includes}) + message(STATUS "dlib_needed_private_includes: " ${dlib_needed_private_includes}) + message(STATUS "dlib_needed_private_libraries: " ${dlib_needed_private_libraries}) + target_include_directories(dlib + INTERFACE $ + INTERFACE $ + PUBLIC ${dlib_needed_public_includes} + PRIVATE ${dlib_needed_private_includes} + ) + target_link_libraries(dlib PUBLIC ${dlib_needed_public_libraries} ${dlib_needed_public_ldflags}) + target_link_libraries(dlib PRIVATE ${dlib_needed_private_libraries}) + target_compile_options(dlib PUBLIC ${dlib_needed_public_cflags}) + message("222 dlib_needed_public_cflags: " ${dlib_needed_public_cflags}) + + + if (DLIB_IN_PROJECT_BUILD) + target_compile_options(dlib PUBLIC ${active_preprocessor_switches}) + else () + # These are private in this case because they will be controlled by the + # contents of dlib/config.h once it's installed. But for in project + # builds, there is no real config.h so they are public in the above case. + target_compile_options(dlib PRIVATE ${active_preprocessor_switches}) + # Do this so that dlib/config.h won't set DLIB_NOT_CONFIGURED. This will then allow + # the code in dlib/threads_kernel_shared.cpp to emit a linker error for users who + # don't use the configured config.h file generated by cmake. + target_compile_options(dlib PRIVATE -DDLIB__CMAKE_GENERATED_A_CONFIG_H_FILE) + + # Do this so that dlib/config.h can record the version of dlib it's configured with + # and ultimately issue a linker error to people who try to use a binary dlib that is + # the wrong version. + set(DLIB_CHECK_FOR_VERSION_MISMATCH + DLIB_VERSION_MISMATCH_CHECK__EXPECTED_VERSION_${CPACK_PACKAGE_VERSION_MAJOR}_${CPACK_PACKAGE_VERSION_MINOR}_${CPACK_PACKAGE_VERSION_PATCH}) + target_compile_options(dlib PRIVATE "-DDLIB_CHECK_FOR_VERSION_MISMATCH=${DLIB_CHECK_FOR_VERSION_MISMATCH}") + endif () + + message(STATUS "DLIB_TEST_COMPILE_ALL_SOURCE_CPP: " ${DLIB_TEST_COMPILE_ALL_SOURCE_CPP}) + # Allow the unit tests to ask us to compile the all/source.cpp file just to make sure it compiles. + if (DLIB_TEST_COMPILE_ALL_SOURCE_CPP) + add_library(dlib_all_source_cpp STATIC all/source.cpp) + target_link_libraries(dlib_all_source_cpp dlib) + target_compile_options(dlib_all_source_cpp PUBLIC ${active_preprocessor_switches}) + target_compile_features(dlib_all_source_cpp PUBLIC cxx_std_14) + endif () + + target_compile_features(dlib PUBLIC cxx_std_14) + if ((MSVC AND CMAKE_VERSION VERSION_LESS 3.11)) + target_compile_options(dlib PUBLIC ${active_compile_opts}) + target_compile_options(dlib PRIVATE ${active_compile_opts_private}) + else () + target_compile_options(dlib PUBLIC $<$:${active_compile_opts}>) + target_compile_options(dlib PRIVATE $<$:${active_compile_opts_private}>) + endif () + + # Install the library + if (NOT DLIB_IN_PROJECT_BUILD) + string(REPLACE ";" " " pkg_config_dlib_needed_libraries "${dlib_needed_public_libraries}") + # Make the -I include options for pkg-config + foreach (ITR ${dlib_needed_public_includes}) + set(pkg_config_dlib_needed_includes "${pkg_config_dlib_needed_includes} -I${ITR}") + endforeach () + set_target_properties(dlib PROPERTIES + VERSION ${VERSION}) + install(TARGETS dlib + EXPORT dlib + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} # Windows considers .dll to be runtime artifacts + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}) + + install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib + FILES_MATCHING + PATTERN "*.h" + PATTERN "*.cmake" + PATTERN "*_tutorial.txt" + PATTERN "cassert" + PATTERN "cstring" + PATTERN "fstream" + PATTERN "iomanip" + PATTERN "iosfwd" + PATTERN "iostream" + PATTERN "istream" + PATTERN "locale" + PATTERN "ostream" + PATTERN "sstream" + REGEX "${CMAKE_CURRENT_BINARY_DIR}" EXCLUDE) + + + configure_file(${PROJECT_SOURCE_DIR}/config.h.in ${CMAKE_CURRENT_BINARY_DIR}/config.h) + # overwrite config.h with the configured one + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/config.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) + + configure_file(${PROJECT_SOURCE_DIR}/revision.h.in ${CMAKE_CURRENT_BINARY_DIR}/revision.h) + install(FILES ${CMAKE_CURRENT_BINARY_DIR}/revision.h DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/dlib) + + ## Config.cmake generation and installation + + set(ConfigPackageLocation "${CMAKE_INSTALL_LIBDIR}/cmake/dlib") + install(EXPORT dlib + NAMESPACE dlib:: + DESTINATION ${ConfigPackageLocation}) + + configure_file(cmake_utils/dlibConfig.cmake.in "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" @ONLY) + + include(CMakePackageConfigHelpers) + write_basic_package_version_file( + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" + VERSION ${VERSION} + COMPATIBILITY AnyNewerVersion + ) + + install(FILES + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfig.cmake" + "${CMAKE_CURRENT_BINARY_DIR}/config/dlibConfigVersion.cmake" + DESTINATION ${ConfigPackageLocation}) + + ## dlib-1.pc generation and installation + + configure_file("cmake_utils/dlib.pc.in" "dlib-1.pc" @ONLY) + install(FILES "${CMAKE_CURRENT_BINARY_DIR}/dlib-1.pc" + DESTINATION "${CMAKE_INSTALL_LIBDIR}/pkgconfig") + + # Add a cpack "package" target. This will create an archive containing + # the built library file, the header files, and cmake and pkgconfig + # configuration files. + include(CPack) + + endif () + +endif () + +if (MSVC) + # Give the output library files names that are unique functions of the + # visual studio mode that compiled them. We do this so that people who + # compile dlib and then copy the .lib files around (which they shouldn't be + # doing in the first place!) will hopefully be slightly less confused by + # what happens since, at the very least, the filenames will indicate what + # visual studio runtime they go with. + math(EXPR numbits ${CMAKE_SIZEOF_VOID_P}*8) + set_target_properties(dlib PROPERTIES DEBUG_POSTFIX "${VERSION}_debug_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES RELEASE_POSTFIX "${VERSION}_release_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES MINSIZEREL_POSTFIX "${VERSION}_minsizerel_${numbits}bit_msvc${MSVC_VERSION}") + set_target_properties(dlib PROPERTIES RELWITHDEBINFO_POSTFIX "${VERSION}_relwithdebinfo_${numbits}bit_msvc${MSVC_VERSION}") +endif () + +# Check if we are being built as part of a pybind11 module. +if (COMMAND pybind11_add_module) + # Don't export unnecessary symbols. + set_target_properties(dlib PROPERTIES CXX_VISIBILITY_PRESET "hidden") + set_target_properties(dlib PROPERTIES CUDA_VISIBILITY_PRESET "hidden") +endif () + +if (WIN32 AND mkl_iomp_dll) + # If we are using the Intel MKL on windows then try and copy the iomp dll + # file to the output folder. We do this since a very large number of + # windows users don't understand that they need to add the Intel MKL's + # folders to their PATH to use the Intel MKL. They then complain on the + # dlib forums. Copying the Intel MKL dlls to the output directory removes + # the need to add the Intel MKL to the PATH. + if (CMAKE_LIBRARY_OUTPUT_DIRECTORY) + add_custom_command(TARGET dlib POST_BUILD + # In some newer versions of windows/visual studio the output Config folder doesn't + # exist at first, so you can't copy to it unless you make it yourself. So make + # sure the target folder exists first. + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/" + COMMAND ${CMAKE_COMMAND} -E copy "${mkl_iomp_dll}" "${CMAKE_LIBRARY_OUTPUT_DIRECTORY}/" + ) + else () + add_custom_command(TARGET dlib POST_BUILD + # In some newer versions of windows/visual studio the output Config folder doesn't + # exist at first, so you can't copy to it unless you make it yourself. So make + # sure the target folder exists first. + COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_BINARY_DIR}/$/" + COMMAND ${CMAKE_COMMAND} -E copy "${mkl_iomp_dll}" "${CMAKE_BINARY_DIR}/$/" + ) + endif () +endif () + +add_library(dlib::dlib ALIAS dlib) diff --git a/dlib/LICENSE.txt b/dlib/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..127a5bc39ba030c7cb99cc0aedc4f280ffe27310 --- /dev/null +++ b/dlib/LICENSE.txt @@ -0,0 +1,23 @@ +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. \ No newline at end of file diff --git a/dlib/algs.h b/dlib/algs.h new file mode 100644 index 0000000000000000000000000000000000000000..faa5ca1fe884247438a7bce11c7db73dc0386fbf --- /dev/null +++ b/dlib/algs.h @@ -0,0 +1,919 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + +#ifndef DLIB_ALGs_ +#define DLIB_ALGs_ + +// this file contains miscellaneous stuff + +// Give people who forget the -std=c++14 option a reminder +#if (defined(__GNUC__) && ((__GNUC__ >= 5 && __GNUC_MINOR__ >= 0) || (__GNUC__ > 5))) || \ + (defined(__clang__) && ((__clang_major__ >= 3 && __clang_minor__ >= 4) || (__clang_major__ >= 3))) + #if __cplusplus < 201402L + #error "Dlib requires C++14 support. Give your compiler the -std=c++14 option to enable it." + #endif +#endif + +#if defined __NVCC__ + // Disable the "statement is unreachable" message since it will go off on code that is + // actually reachable but just happens to not be reachable sometimes during certain + // template instantiations. + #ifdef __NVCC_DIAG_PRAGMA_SUPPORT__ + #pragma nv_diag_suppress code_is_unreachable + #else + #pragma diag_suppress code_is_unreachable + #endif +#endif + + +#ifdef _MSC_VER + +#if _MSC_VER < 1900 +#error "dlib versions newer than v19.1 use C++11 and therefore require Visual Studio 2015 or newer." +#endif + +// Disable the following warnings for Visual Studio + +// this is to disable the "'this' : used in base member initializer list" +// warning you get from some of the GUI objects since all the objects +// require that their parent class be passed into their constructor. +// In this case though it is totally safe so it is ok to disable this warning. +#pragma warning(disable : 4355) + +// This is a warning you get sometimes when Visual Studio performs a Koenig Lookup. +// This is a bug in visual studio. It is a totally legitimate thing to +// expect from a compiler. +#pragma warning(disable : 4675) + +// This is a warning you get from visual studio 2005 about things in the standard C++ +// library being "deprecated." I checked the C++ standard and it doesn't say jack +// about any of them (I checked the searchable PDF). So this warning is total Bunk. +#pragma warning(disable : 4996) + +// This is a warning you get from visual studio 2003: +// warning C4345: behavior change: an object of POD type constructed with an initializer +// of the form () will be default-initialized. +// I love it when this compiler gives warnings about bugs in previous versions of itself. +#pragma warning(disable : 4345) + + +// Disable warnings about conversion from size_t to unsigned long and long. +#pragma warning(disable : 4267) + +// Disable warnings about conversion from double to float +#pragma warning(disable : 4244) +#pragma warning(disable : 4305) + +// Disable "warning C4180: qualifier applied to function type has no meaning; ignored". +// This warning happens often in generic code that works with functions and isn't useful. +#pragma warning(disable : 4180) + +// Disable "warning C4290: C++ exception specification ignored except to indicate a function is not __declspec(nothrow)" +#pragma warning(disable : 4290) + + +// DNN module uses template-based network declaration that leads to very long +// type names. Visual Studio will produce Warning C4503 in such cases. https://msdn.microsoft.com/en-us/library/074af4b6.aspx says +// that correct binaries are still produced even when this warning happens, but linker errors from visual studio, if they occur could be confusing. +#pragma warning( disable: 4503 ) + + +#endif + +#ifdef __BORLANDC__ +// Disable the following warnings for the Borland Compilers +// +// These warnings just say that the compiler is refusing to inline functions with +// loops or try blocks in them. +// +#pragma option -w-8027 +#pragma option -w-8026 +#endif + +#include // for the exceptions + +#ifdef __CYGWIN__ +namespace std +{ + typedef std::basic_string wstring; +} +#endif + +#include "platform.h" +#include "windows_magic.h" + + +#include // for std::swap +#include // for std::bad_alloc +#include +#include +#include +#include // for std::isfinite for is_finite() +#include "assert.h" +#include "error.h" +#include "noncopyable.h" +#include "enable_if.h" +#include "uintn.h" +#include "numeric_constants.h" +#include "memory_manager_stateless/memory_manager_stateless_kernel_1.h" // for the default memory manager +#include "type_traits.h" + +// ---------------------------------------------------------------------------------------- +/*!A _dT !*/ + +template +inline charT _dTcast (const char a, const wchar_t b); +template <> +inline char _dTcast (const char a, const wchar_t ) { return a; } +template <> +inline wchar_t _dTcast (const char , const wchar_t b) { return b; } + +template +inline const charT* _dTcast ( const char* a, const wchar_t* b); +template <> +inline const char* _dTcast ( const char* a, const wchar_t* ) { return a; } +template <> +inline const wchar_t* _dTcast ( const char* , const wchar_t* b) { return b; } + + +#define _dT(charT,str) _dTcast(str,L##str) +/*! + requires + - charT == char or wchar_t + - str == a string or character literal + ensures + - returns the literal in the form of a charT type literal. +!*/ + +// ---------------------------------------------------------------------------------------- + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /*!A default_memory_manager + + This memory manager just calls new and delete directly. + + !*/ + typedef memory_manager_stateless_kernel_1 default_memory_manager; + +// ---------------------------------------------------------------------------------------- + + /*!A swap !*/ + // make swap available in the dlib namespace + using std::swap; + +// ---------------------------------------------------------------------------------------- + + /*! + Here is where I define my return codes. It is + important that they all be < 0. + !*/ + + enum general_return_codes + { + TIMEOUT = -1, + WOULDBLOCK = -2, + OTHER_ERROR = -3, + SHUTDOWN = -4, + PORTINUSE = -5 + }; + +// ---------------------------------------------------------------------------------------- + + inline unsigned long square_root ( + unsigned long value + ) + /*! + requires + - value <= 2^32 - 1 + ensures + - returns the square root of value. if the square root is not an + integer then it will be rounded up to the nearest integer. + !*/ + { + unsigned long x; + + // set the initial guess for what the root is depending on + // how big value is + if (value < 3) + return value; + else if (value < 4096) // 12 + x = 45; + else if (value < 65536) // 16 + x = 179; + else if (value < 1048576) // 20 + x = 717; + else if (value < 16777216) // 24 + x = 2867; + else if (value < 268435456) // 28 + x = 11469; + else // 32 + x = 45875; + + + + // find the root + x = (x + value/x)>>1; + x = (x + value/x)>>1; + x = (x + value/x)>>1; + x = (x + value/x)>>1; + + + + if (x*x < value) + return x+1; + else + return x; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void median ( + T& one, + T& two, + T& three + ); + /*! + requires + - T implements operator< + - T is swappable by a global swap() + ensures + - #one is the median + - #one, #two, and #three is some permutation of one, two, and three. + !*/ + + + template < + typename T + > + void median ( + T& one, + T& two, + T& three + ) + { + using std::swap; + using dlib::swap; + + if ( one < two ) + { + // one < two + if ( two < three ) + { + // one < two < three : two + swap(one,two); + + } + else + { + // one < two >= three + if ( one < three) + { + // three + swap(three,one); + } + } + + } + else + { + // one >= two + if ( three < one ) + { + // three <= one >= two + if ( three < two ) + { + // two + swap(two,one); + } + else + { + // three + swap(three,one); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + namespace relational_operators + { + template < + typename A, + typename B + > + constexpr bool operator> ( + const A& a, + const B& b + ) { return b < a; } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator!= ( + const A& a, + const B& b + ) { return !(a == b); } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator<= ( + const A& a, + const B& b + ) { return !(b < a); } + + // --------------------------------- + + template < + typename A, + typename B + > + constexpr bool operator>= ( + const A& a, + const B& b + ) { return !(a < b); } + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + void exchange ( + T& a, + T& b + ) + /*! + This function does the exact same thing that global swap does and it does it by + just calling swap. But a lot of compilers have problems doing a Koenig Lookup + and the fact that this has a different name (global swap has the same name as + the member functions called swap) makes them compile right. + + So this is a workaround but not too ugly of one. But hopefully I can get + rid of this in a few years. So this function is already deprecated. + + This also means you should NOT use this function in your own code unless + you have to support an old buggy compiler that benefits from this hack. + !*/ + { + using std::swap; + using dlib::swap; + swap(a,b); + } + +// ---------------------------------------------------------------------------------------- + + struct general_ {}; + struct special_ : general_ {}; + template struct int_ { typedef int type; }; + +// ---------------------------------------------------------------------------------------- + + + /*!A is_same_object + + This is a templated function which checks if both of its arguments are actually + references to the same object. It returns true if they are and false otherwise. + + !*/ + + // handle the case where T and U are unrelated types. + template < typename T, typename U > + std::enable_if_t::value && !std::is_convertible::value, bool> + is_same_object ( + const T& a, + const U& b + ) + { + return ((void*)&a == (void*)&b); + } + + // handle the case where T and U are related types because their pointers can be + // implicitly converted into one or the other. E.g. a derived class and its base class. + // Or where both T and U are just the same type. This way we make sure that if there is a + // valid way to convert between these two pointer types then we will take that route rather + // than the void* approach used otherwise. + template < typename T, typename U > + std::enable_if_t::value || std::is_convertible::value, bool> + is_same_object ( + const T& a, + const U& b + ) + { + return (&a == &b); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + class copy_functor + { + public: + void operator() ( + const T& source, + T& destination + ) const + { + destination = source; + } + }; + +// ---------------------------------------------------------------------------------------- + + /*!A static_switch + + To use this template you give it some number of boolean expressions and it + tells you which one of them is true. If more than one of them is true then + it causes a compile time error. + + for example: + static_switch<1 + 1 == 2, 4 - 1 == 4>::value == 1 // because the first expression is true + static_switch<1 + 1 == 3, 4 == 4>::value == 2 // because the second expression is true + static_switch<1 + 1 == 3, 4 == 5>::value == 0 // 0 here because none of them are true + static_switch<1 + 1 == 2, 4 == 4>::value == compiler error // because more than one expression is true + !*/ + + template < bool v1 = 0, bool v2 = 0, bool v3 = 0, bool v4 = 0, bool v5 = 0, + bool v6 = 0, bool v7 = 0, bool v8 = 0, bool v9 = 0, bool v10 = 0, + bool v11 = 0, bool v12 = 0, bool v13 = 0, bool v14 = 0, bool v15 = 0 > + struct static_switch; + + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 0; }; + template <> struct static_switch<1,0,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 1; }; + template <> struct static_switch<0,1,0,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 2; }; + template <> struct static_switch<0,0,1,0,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 3; }; + template <> struct static_switch<0,0,0,1,0,0,0,0,0,0,0,0,0,0,0> { const static int value = 4; }; + template <> struct static_switch<0,0,0,0,1,0,0,0,0,0,0,0,0,0,0> { const static int value = 5; }; + template <> struct static_switch<0,0,0,0,0,1,0,0,0,0,0,0,0,0,0> { const static int value = 6; }; + template <> struct static_switch<0,0,0,0,0,0,1,0,0,0,0,0,0,0,0> { const static int value = 7; }; + template <> struct static_switch<0,0,0,0,0,0,0,1,0,0,0,0,0,0,0> { const static int value = 8; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,1,0,0,0,0,0,0> { const static int value = 9; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,1,0,0,0,0,0> { const static int value = 10; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,1,0,0,0,0> { const static int value = 11; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,1,0,0,0> { const static int value = 12; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,1,0,0> { const static int value = 13; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,1,0> { const static int value = 14; }; + template <> struct static_switch<0,0,0,0,0,0,0,0,0,0,0,0,0,0,1> { const static int value = 15; }; + +// ---------------------------------------------------------------------------------------- + + template + std::enable_if_t::value, bool> is_finite(T value) + /*! + requires + - value must be some kind of scalar type such as int or double + ensures + - returns true if value is a finite value (e.g. not infinity or NaN) and false + otherwise. + !*/ + { + return std::isfinite(value); + } + + template + std::enable_if_t::value, bool> is_finite(T value) + { + return std::isfinite((double)value); + } + +// ---------------------------------------------------------------------------------------- + + /*!A promote + + This is a template that takes one of the built in scalar types and gives you another + scalar type that should be big enough to hold sums of values from the original scalar + type. The new scalar type will also always be signed. + + For example, promote::type == int32 + !*/ + + template struct promote; + template struct promote { typedef int32 type; }; + template struct promote { typedef int32 type; }; + template struct promote { typedef int64 type; }; + template struct promote { typedef int64 type; }; + + template <> struct promote { typedef double type; }; + template <> struct promote { typedef double type; }; + template <> struct promote { typedef long double type; }; + +// ---------------------------------------------------------------------------------------- + + /*!A assign_zero_if_built_in_scalar_type + + This function assigns its argument the value of 0 if it is a built in scalar + type according to the is_built_in_scalar_type<> template. If it isn't a + built in scalar type then it does nothing. + !*/ + + template inline typename disable_if,void>::type assign_zero_if_built_in_scalar_type (T&){} + template inline typename enable_if,void>::type assign_zero_if_built_in_scalar_type (T& a){a=0;} + +// ---------------------------------------------------------------------------------------- + + template + T put_in_range ( + const T& a, + const T& b, + const T& val + ) + /*! + requires + - T is a type that looks like double, float, int, or so forth + ensures + - if (val is within the range [a,b]) then + - returns val + - else + - returns the end of the range [a,b] that is closest to val + !*/ + { + if (a < b) + { + if (val < a) + return a; + else if (val > b) + return b; + } + else + { + if (val < b) + return b; + else if (val > a) + return a; + } + + return val; + } + + // overload for double + inline double put_in_range(const double& a, const double& b, const double& val) + { return put_in_range(a,b,val); } + +// ---------------------------------------------------------------------------------------- + + /*!A tabs + + This is a template to compute the absolute value a number at compile time. + + For example, + abs<-4>::value == 4 + abs<4>::value == 4 + !*/ + + template + struct tabs { const static long value = x; }; + template + struct tabs::type> { const static long value = -x; }; + +// ---------------------------------------------------------------------------------------- + + /*!A tmax + + This is a template to compute the max of two values at compile time + + For example, + abs<4,7>::value == 7 + !*/ + + template + struct tmax { const static long value = x; }; + template + struct tmax x)>::type> { const static long value = y; }; + +// ---------------------------------------------------------------------------------------- + + /*!A tmin + + This is a template to compute the min of two values at compile time + + For example, + abs<4,7>::value == 4 + !*/ + + template + struct tmin { const static long value = x; }; + template + struct tmin::type> { const static long value = y; }; + +// ---------------------------------------------------------------------------------------- + +#define DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(testname, returnT, funct_name, args) \ + struct _two_bytes_##testname { char a[2]; }; \ + template < typename T, returnT (T::*funct)args > \ + struct _helper_##testname { typedef char type; }; \ + template \ + static char _has_##testname##_helper( typename _helper_##testname::type ) { return 0;} \ + template \ + static _two_bytes_##testname _has_##testname##_helper(int) { return _two_bytes_##testname();} \ + template struct _##testname##workaroundbug { \ + const static unsigned long U = sizeof(_has_##testname##_helper('a')); }; \ + template ::U > \ + struct testname { static const bool value = false; }; \ + template \ + struct testname { static const bool value = true; }; + /*!A DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST + + The DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST() macro is used to define traits templates + that tell you if a class has a certain member function. For example, to make a + test to see if a class has a public method with the signature void print(int) you + would say: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, (int)) + + Then you can check if a class, T, has this method by looking at the boolean value: + has_print::value + which will be true if the member function is in the T class. + + Note that you can test for member functions taking no arguments by simply passing + in empty () like so: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()) + This would test for a member of the form: + void print(). + + To test for const member functions you would use a statement such as this: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, print, ()const) + This would test for a member of the form: + void print() const. + + To test for const templated member functions you would use a statement such as this: + DLIB_MAKE_HAS_MEMBER_FUNCTION_TEST(has_print, void, template print, ()) + This would test for a member of the form: + template void print(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template class funct_wrap0 + { + public: + funct_wrap0(T (&f_)()):f(f_){} + T operator()() const { return f(); } + private: + T (&f)(); + }; + template class funct_wrap1 + { + public: + funct_wrap1(T (&f_)(A0)):f(f_){} + T operator()(A0 a0) const { return f(a0); } + private: + T (&f)(A0); + }; + template class funct_wrap2 + { + public: + funct_wrap2(T (&f_)(A0,A1)):f(f_){} + T operator()(A0 a0, A1 a1) const { return f(a0,a1); } + private: + T (&f)(A0,A1); + }; + template class funct_wrap3 + { + public: + funct_wrap3(T (&f_)(A0,A1,A2)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2) const { return f(a0,a1,a2); } + private: + T (&f)(A0,A1,A2); + }; + template class funct_wrap4 + { + public: + funct_wrap4(T (&f_)(A0,A1,A2,A3)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2, A3 a3) const { return f(a0,a1,a2,a3); } + private: + T (&f)(A0,A1,A2,A3); + }; + template class funct_wrap5 + { + public: + funct_wrap5(T (&f_)(A0,A1,A2,A3,A4)):f(f_){} + T operator()(A0 a0, A1 a1, A2 a2, A3 a3, A4 a4) const { return f(a0,a1,a2,a3,a4); } + private: + T (&f)(A0,A1,A2,A3,A4); + }; + + /*!A wrap_function + + This is a template that allows you to turn a global function into a + function object. The reason for this template's existence is so you can + do stuff like this: + + template + void call_funct(const T& funct) + { cout << funct(); } + + std::string test() { return "asdfasf"; } + + int main() + { + call_funct(wrap_function(test)); + } + + The above code doesn't work right on some compilers if you don't + use wrap_function. + !*/ + + template + funct_wrap0 wrap_function(T (&f)()) { return funct_wrap0(f); } + template + funct_wrap1 wrap_function(T (&f)(A0)) { return funct_wrap1(f); } + template + funct_wrap2 wrap_function(T (&f)(A0, A1)) { return funct_wrap2(f); } + template + funct_wrap3 wrap_function(T (&f)(A0, A1, A2)) { return funct_wrap3(f); } + template + funct_wrap4 wrap_function(T (&f)(A0, A1, A2, A3)) { return funct_wrap4(f); } + template + funct_wrap5 wrap_function(T (&f)(A0, A1, A2, A3, A4)) { return funct_wrap5(f); } + +// ---------------------------------------------------------------------------------------- + + template + class stack_based_memory_block : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a simple container for a block of memory + of bSIZE bytes. This memory block is located on the stack + and properly aligned to hold any kind of object. + !*/ + public: + static const unsigned long size = bSIZE; + + stack_based_memory_block(): data(mem.data) {} + + void* get () { return data; } + /*! + ensures + - returns a pointer to the block of memory contained in this object + !*/ + + const void* get () const { return data; } + /*! + ensures + - returns a pointer to the block of memory contained in this object + !*/ + + private: + + // You obviously can't have a block of memory that has zero bytes in it. + COMPILE_TIME_ASSERT(bSIZE > 0); + + union mem_block + { + // All of this garbage is to make sure this union is properly aligned + // (a union is always aligned such that everything in it would be properly + // aligned. So the assumption here is that one of these objects has + // a large enough alignment requirement to satisfy any object this + // block of memory might be cast into). + void* void_ptr; + int integer; + struct { + void (stack_based_memory_block::*callback)(); + stack_based_memory_block* o; + } stuff; + long double more_stuff; + + uint64 var1; + uint32 var2; + double var3; + + char data[size]; + } mem; + + // The reason for having this variable is that doing it this way avoids + // warnings from gcc about violations of strict-aliasing rules. + void* const data; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename F + > + auto max_scoring_element( + const T& container, + F score_func + ) -> decltype(std::make_pair(*container.begin(), 0.0)) + /*! + requires + - container has .begin() and .end(), allowing it to be enumerated. + - score_func() is a function that takes an element of the container and returns a double. + ensures + - This function finds the element of container that has the largest score, + according to score_func(), and returns a std::pair containing that maximal + element along with the score. + - If the container is empty then make_pair(a default initialized object, -infinity) is returned. + !*/ + { + double best_score = -std::numeric_limits::infinity(); + auto best_i = container.begin(); + for (auto i = container.begin(); i != container.end(); ++i) + { + auto score = score_func(*i); + if (score > best_score) + { + best_score = score; + best_i = i; + } + } + + using item_type = typename std::remove_reference::type; + + if (best_i == container.end()) + return std::make_pair(item_type(), best_score); + else + return std::make_pair(*best_i, best_score); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename F + > + auto min_scoring_element( + const T& container, + F score_func + ) -> decltype(std::make_pair(*container.begin(), 0.0)) + /*! + requires + - container has .begin() and .end(), allowing it to be enumerated. + - score_func() is a function that takes an element of the container and returns a double. + ensures + - This function finds the element of container that has the smallest score, + according to score_func(), and returns a std::pair containing that minimal + element along with the score. + - If the container is empty then make_pair(a default initialized object, infinity) is returned. + !*/ + { + double best_score = std::numeric_limits::infinity(); + auto best_i = container.begin(); + for (auto i = container.begin(); i != container.end(); ++i) + { + auto score = score_func(*i); + if (score < best_score) + { + best_score = score; + best_i = i; + } + } + + using item_type = typename std::remove_reference::type; + + if (best_i == container.end()) + return std::make_pair(item_type(), best_score); + else + return std::make_pair(*best_i, best_score); + } + +// ---------------------------------------------------------------------------------------- + + namespace detail + { + template + constexpr void for_each_impl(Tuple&& t, F&& f, std::index_sequence) + { +#ifdef __cpp_fold_expressions + (std::forward(f)(std::get(std::forward(t))),...); +#else + (void)std::initializer_list{(std::forward(f)(std::get(std::forward(t))),0)...}; +#endif + } + } + + template + constexpr void for_each_in_tuple(Tuple&& t, F&& f) + { + detail::for_each_impl(std::forward(t), std::forward(f), + std::make_index_sequence>::value>{}); + } + +// ---------------------------------------------------------------------------------------- +} + +#endif // DLIB_ALGs_ + diff --git a/dlib/all/source.cpp b/dlib/all/source.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8e2834077553d840ff6a1cd81724fefa34b26295 --- /dev/null +++ b/dlib/all/source.cpp @@ -0,0 +1,100 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ALL_SOURCe_ +#define DLIB_ALL_SOURCe_ + +#if defined(DLIB_ALGs_) || defined(DLIB_PLATFORm_) +#include "../dlib_basic_cpp_build_tutorial.txt" +#endif + +// ISO C++ code +#include "../base64/base64_kernel_1.cpp" +#include "../bigint/bigint_kernel_1.cpp" +#include "../bigint/bigint_kernel_2.cpp" +#include "../bit_stream/bit_stream_kernel_1.cpp" +#include "../entropy_decoder/entropy_decoder_kernel_1.cpp" +#include "../entropy_decoder/entropy_decoder_kernel_2.cpp" +#include "../entropy_encoder/entropy_encoder_kernel_1.cpp" +#include "../entropy_encoder/entropy_encoder_kernel_2.cpp" +#include "../md5/md5_kernel_1.cpp" +#include "../tokenizer/tokenizer_kernel_1.cpp" +#include "../unicode/unicode.cpp" +#include "../test_for_odr_violations.cpp" + + + + +#ifndef DLIB_ISO_CPP_ONLY +// Code that depends on OS specific APIs + +// include this first so that it can disable the older version +// of the winsock API when compiled in windows. +#include "../sockets/sockets_kernel_1.cpp" +#include "../bsp/bsp.cpp" + +#include "../dir_nav/dir_nav_kernel_1.cpp" +#include "../dir_nav/dir_nav_kernel_2.cpp" +#include "../dir_nav/dir_nav_extensions.cpp" +#include "../linker/linker_kernel_1.cpp" +#include "../logger/extra_logger_headers.cpp" +#include "../logger/logger_kernel_1.cpp" +#include "../logger/logger_config_file.cpp" +#include "../misc_api/misc_api_kernel_1.cpp" +#include "../misc_api/misc_api_kernel_2.cpp" +#include "../sockets/sockets_extensions.cpp" +#include "../sockets/sockets_kernel_2.cpp" +#include "../sockstreambuf/sockstreambuf.cpp" +#include "../sockstreambuf/sockstreambuf_unbuffered.cpp" +#include "../server/server_kernel.cpp" +#include "../server/server_iostream.cpp" +#include "../server/server_http.cpp" +#include "../threads/multithreaded_object_extension.cpp" +#include "../threads/threaded_object_extension.cpp" +#include "../threads/threads_kernel_1.cpp" +#include "../threads/threads_kernel_2.cpp" +#include "../threads/threads_kernel_shared.cpp" +#include "../threads/thread_pool_extension.cpp" +#include "../threads/async.cpp" +#include "../timer/timer.cpp" +#include "../stack_trace.cpp" + +#ifdef DLIB_PNG_SUPPORT +#include "../image_loader/png_loader.cpp" +#include "../image_saver/save_png.cpp" +#endif + +#ifdef DLIB_JPEG_SUPPORT +#include "../image_loader/jpeg_loader.cpp" +#include "../image_saver/save_jpeg.cpp" +#endif + +#ifndef DLIB_NO_GUI_SUPPORT +#include "../gui_widgets/fonts.cpp" +#include "../gui_widgets/widgets.cpp" +#include "../gui_widgets/drawable.cpp" +#include "../gui_widgets/canvas_drawing.cpp" +#include "../gui_widgets/style.cpp" +#include "../gui_widgets/base_widgets.cpp" +#include "../gui_core/gui_core_kernel_1.cpp" +#include "../gui_core/gui_core_kernel_2.cpp" +#endif // DLIB_NO_GUI_SUPPORT + +#include "../rocm/cpu_dlib.cpp" +#include "../rocm/tensor_tools.cpp" +#include "../data_io/image_dataset_metadata.cpp" +#include "../data_io/mnist.cpp" +#include "../data_io/cifar.cpp" +#include "../svm/auto.cpp" +#include "../global_optimization/global_function_search.cpp" +#include "../filtering/kalman_filter.cpp" + +#endif // DLIB_ISO_CPP_ONLY + + + + + +#define DLIB_ALL_SOURCE_END + +#endif // DLIB_ALL_SOURCe_ + diff --git a/dlib/any.h b/dlib/any.h new file mode 100644 index 0000000000000000000000000000000000000000..01f047066782b037bbcbb58c6212bdb32fec1007 --- /dev/null +++ b/dlib/any.h @@ -0,0 +1,13 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_ +#define DLIB_AnY_ + +#include "any/any.h" +#include "any/any_trainer.h" +#include "any/any_decision_function.h" +#include "any/any_function.h" + +#endif // DLIB_AnY_ + + diff --git a/dlib/any/any.h b/dlib/any/any.h new file mode 100644 index 0000000000000000000000000000000000000000..f2db32bcaf43a5c2a465ec6ff17f1441f1949cc1 --- /dev/null +++ b/dlib/any/any.h @@ -0,0 +1,72 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_H_ +#define DLIB_AnY_H_ + +#include "any_abstract.h" +#include +#include "storage.h" + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + class any + { + public: + any() = default; + any(const any& other) = default; + any& operator=(const any& other) = default; + any(any&& other) = default; + any& operator=(any&& other) = default; + + template< + typename T, + std::enable_if_t, any>::value, bool> = true + > + any(T&& item) + : storage{std::forward(item)} + { + } + + template< + typename T, + typename T_ = std::decay_t, + std::enable_if_t::value, bool> = true + > + any& operator=(T&& item) + { + if (contains()) + storage.unsafe_get() = std::forward(item); + else + *this = std::move(any{std::forward(item)}); + return *this; + } + + bool is_empty() const { return storage.is_empty(); } + void clear() { storage.clear(); } + void swap (any& item) { std::swap(*this, item); } + + template bool contains() const { return storage.contains();} + template T& cast_to() { return storage.cast_to(); } + template const T& cast_to() const { return storage.cast_to(); } + template T& get() { return storage.get(); } + + private: + te::storage_heap storage; + }; + +// ---------------------------------------------------------------------------------------- + + template T& any_cast(any& a) { return a.cast_to(); } + template const T& any_cast(const any& a) { return a.cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_H_ + + + diff --git a/dlib/any/any_abstract.h b/dlib/any/any_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..3dcc069e1c28cbc4fa2bb6a22fb95eb53af76e2d --- /dev/null +++ b/dlib/any/any_abstract.h @@ -0,0 +1,209 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_ABSTRACT_H_ +#ifdef DLIB_AnY_ABSTRACT_H_ + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bad_any_cast : public std::bad_cast + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is the exception class used by the any object. + It is used to indicate when someone attempts to cast an any + object into a type which isn't contained in the any object. + !*/ + + public: + virtual const char* what() const throw() { return "bad_any_cast"; } + }; + +// ---------------------------------------------------------------------------------------- + + class any + { + /*! + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is basically a type-safe version of a void*. In particular, + it is a container which can contain only one object but the object may + be of any type. + + It is somewhat like the type_safe_union except you don't have to declare + the set of possible content types beforehand. So in some sense this is + like a less type-strict version of the type_safe_union. + !*/ + + public: + + any( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any ( + const any& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + any_function ( + any_function&& item + ); + /*! + ensures + - #item.is_empty() == true + - moves item into *this. + !*/ + + template < typename T > + any ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any& operator= ( + const any& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any& item + ); + /*! + ensures + - swaps *this and item + - does not invalidate pointers or references to the object contained + inside *this or item. Moreover, a pointer or reference to the object in + *this will now refer to the contents of #item and vice versa. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + T& any_cast( + any& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T + > + const T& any_cast( + const any& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_ABSTRACT_H_ + + diff --git a/dlib/any/any_decision_function.h b/dlib/any/any_decision_function.h new file mode 100644 index 0000000000000000000000000000000000000000..27eeb4ff35d94b68635c8d3777b622994fdd0c56 --- /dev/null +++ b/dlib/any/any_decision_function.h @@ -0,0 +1,23 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_DECISION_FUNCTION_Hh_ +#define DLIB_AnY_DECISION_FUNCTION_Hh_ + +#include "any_decision_function_abstract.h" +#include "any_function.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + using any_decision_function = any_function; + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_DECISION_FUNCTION_Hh_ diff --git a/dlib/any/any_decision_function_abstract.h b/dlib/any/any_decision_function_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..756690ccfa15eba7ca5d878d3afc803ba33b6c6f --- /dev/null +++ b/dlib/any/any_decision_function_abstract.h @@ -0,0 +1,24 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ +#ifdef DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ + +#include "any_function_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template + using any_decision_function = any_function; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_DECISION_FUNCTION_ABSTRACT_H_ + + + diff --git a/dlib/any/any_function.h b/dlib/any/any_function.h new file mode 100644 index 0000000000000000000000000000000000000000..b6c9a590f8f505be5dec6ebeff376bd900497411 --- /dev/null +++ b/dlib/any/any_function.h @@ -0,0 +1,126 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_FUNCTION_Hh_ +#define DLIB_AnY_FUNCTION_Hh_ + +#include "../assert.h" +#include "../functional.h" +#include "any.h" +#include "any_function_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + class Storage, + class F + > + class any_function_basic; + + template < + class Storage, + class R, + class... Args + > + class any_function_basic + { + private: + template + using is_valid = std::enable_if_t, any_function_basic>::value && + dlib::is_invocable_r::value, + bool>; + + template + static auto make_invoker() + { + return [](void* self, Args... args) -> R { + return dlib::invoke(*reinterpret_cast>(self), + std::forward(args)...); + }; + } + + Storage str; + R (*func)(void*, Args...) = nullptr; + + public: + + using result_type = R; + + constexpr any_function_basic(std::nullptr_t) noexcept {} + constexpr any_function_basic() = default; + constexpr any_function_basic(const any_function_basic& other) = default; + constexpr any_function_basic& operator=(const any_function_basic& other) = default; + + constexpr any_function_basic(any_function_basic&& other) + : str{std::move(other.str)}, + func{std::exchange(other.func, nullptr)} + { + } + + constexpr any_function_basic& operator=(any_function_basic&& other) + { + if (this != &other) + { + str = std::move(other.str); + func = std::exchange(other.func, nullptr); + } + + return *this; + } + + template = true> + any_function_basic( + F&& f + ) : str{std::forward(f)}, + func{make_invoker()} + { + } + + template = true> + any_function_basic( + F* f + ) : str{f}, + func{make_invoker()} + { + } + + R operator()(Args... args) const { + return func(const_cast(str.get_ptr()), std::forward(args)...); + } + + void clear() { str.clear(); } + void swap (any_function_basic& item) { std::swap(*this, item); } + bool is_empty() const noexcept { return str.is_empty() || func == nullptr; } + bool is_set() const noexcept { return !is_empty(); } + explicit operator bool() const noexcept { return is_set(); } + + template bool contains() const { return str.template contains();} + template T& cast_to() { return str.template cast_to(); } + template const T& cast_to() const { return str.template cast_to(); } + template T& get() { return str.template get(); } + }; + +// ---------------------------------------------------------------------------------------- + + template + T& any_cast(any_function_basic& a) { return a.template cast_to(); } + + template + const T& any_cast(const any_function_basic& a) { return a.template cast_to(); } + +// ---------------------------------------------------------------------------------------- + + template + using any_function = any_function_basic, F>; + + template + using any_function_view = any_function_basic; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_FUNCTION_Hh_ + diff --git a/dlib/any/any_function_abstract.h b/dlib/any/any_function_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..8c0da073659a394a6bb2ae3d4e69ccf54af6b099 --- /dev/null +++ b/dlib/any/any_function_abstract.h @@ -0,0 +1,274 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_FUNCTION_ABSTRACT_H_ +#ifdef DLIB_AnY_FUNCTION_ABSTRACT_H_ + +#include "any_abstract.h" +#include "../algs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + class Storage, + typename function_type + > + class any_function_basic + { + /*! + REQUIREMENTS ON Storage + This must be one of the storage types from dlib/any/storage.hh + E.g. storage_heap, storage_stack, etc. + + It determines the method by which any_function_basic holds onto the function it uses. + + REQUIREMENTS ON function_type + This type should be a function signature. Some examples are: + void (int,int) // a function returning nothing and taking two ints + void () // a function returning nothing and taking no arguments + char (double&) // a function returning a char and taking a reference to a double + + The number of arguments in the function must be no greater than 10. + + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is a version of dlib::any that is restricted to containing + elements which are some kind of function object with an operator() which + matches the function signature defined by function_type. + + + Here is an example: + #include + #include + #include "dlib/any.h" + using namespace std; + void print_message(string str) { cout << str << endl; } + + int main() + { + dlib::any_function f; + f = print_message; + f("hello world"); // calls print_message("hello world") + } + + Note that any_function_basic objects can be used to store general function + objects (i.e. defined by a class with an overloaded operator()) in + addition to regular global functions. + !*/ + + public: + + // This is the type of object returned by function_type functions. + typedef result_type_for_function_type result_type; + + any_function_basic( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any_function_basic ( + const any_function_basic& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + any_function_basic ( + any_function_basic&& item + ); + /*! + ensures + - moves item into *this. + - The exact move semantics are determined by which Storage type is used. E.g. + storage_heap will result in #item.is_empty()==true but storage_view would result + in #item.is_empty() == false + !*/ + + template < typename Funct > + any_function_basic ( + Funct&& funct + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. calling operator() will invoke funct()) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + bool is_set ( + ) const; + /*! + ensures + - returns !is_empty() + !*/ + + explicit operator bool( + ) const; + /*! + ensures + - returns is_set() + !*/ + + result_type operator(Args... args) ( + ) const; + /*! + requires + - is_empty() == false + - the signature defined by function_type takes no arguments + ensures + - Let F denote the function object contained within *this. Then + this function performs: + return F(std::forward(args)...) + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any_function_basic object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any_function_basic& operator= ( + const any_function_basic& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that the type of copy is determined by the Storage template argument. E.g. + storage_sbo will result in a deep copy, while storage_view would result in *this + and item referring to the same underlying function. + !*/ + + void swap ( + any_function_basic& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename function_type + > + T& any_cast( + any_function_basic& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename function_type + > + const T& any_cast( + const any_function_basic& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + /*!A any_function + + A version of any_function_basic (defined above) that owns the function it contains. Uses + the small buffer optimization to make working with small lambdas faster. + !*/ + template + using any_function = any_function_basic, F>; + + /*!A any_function_view + + A version of any_function_basic (defined above) that *DOES NOT* own the function it + contains. It merely holds a pointer to the function given to its constructor. + !*/ + template + using any_function_view = any_function_basic; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_FUNCTION_ABSTRACT_H_ + diff --git a/dlib/any/any_trainer.h b/dlib/any/any_trainer.h new file mode 100644 index 0000000000000000000000000000000000000000..cda7e02890d687d3a4b2d323729ebc158b430c10 --- /dev/null +++ b/dlib/any/any_trainer.h @@ -0,0 +1,119 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_AnY_TRAINER_H_ +#define DLIB_AnY_TRAINER_H_ + +#include "any.h" +#include "any_decision_function.h" +#include "any_trainer_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename scalar_type_ = double + > + class any_trainer + { + public: + using sample_type = sample_type_; + using scalar_type = scalar_type_; + using mem_manager_type = default_memory_manager; + using trained_function_type = any_decision_function; + + any_trainer() = default; + any_trainer(const any_trainer& other) = default; + any_trainer& operator=(const any_trainer& other) = default; + any_trainer(any_trainer&& other) = default; + any_trainer& operator=(any_trainer&& other) = default; + + template < + class T, + class T_ = std::decay_t, + std::enable_if_t::value, bool> = true + > + any_trainer ( + T&& item + ) : storage{std::forward(item)}, + train_func{[]( + const void* ptr, + const std::vector& samples, + const std::vector& labels + ) -> trained_function_type { + const T_& f = *reinterpret_cast(ptr); + return f.train(samples, labels); + }} + { + } + + template < + class T, + class T_ = std::decay_t, + std::enable_if_t::value, bool> = true + > + any_trainer& operator= ( + T&& item + ) + { + if (contains()) + storage.unsafe_get() = std::forward(item); + else + *this = std::move(any_trainer{std::forward(item)}); + return *this; + } + + trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_empty() == false, + "\t trained_function_type any_trainer::train()" + << "\n\t You can't call train() on an empty any_trainer" + << "\n\t this: " << this + ); + + return train_func(storage.get_ptr(), samples, labels); + } + + bool is_empty() const { return storage.is_empty(); } + void clear() { storage.clear(); } + void swap (any_trainer& item) { std::swap(*this, item); } + + template bool contains() const { return storage.contains();} + template T& cast_to() { return storage.cast_to(); } + template const T& cast_to() const { return storage.cast_to(); } + template T& get() { return storage.get(); } + + private: + te::storage_heap storage; + trained_function_type (*train_func) ( + const void* self, + const std::vector& samples, + const std::vector& labels + ) = nullptr; + }; + +// ---------------------------------------------------------------------------------------- + + template + T& any_cast(any_trainer& a) { return a.template cast_to(); } + + template + const T& any_cast(const any_trainer& a) { return a.template cast_to(); } + +// ---------------------------------------------------------------------------------------- + +} + + +#endif // DLIB_AnY_TRAINER_H_ + + + + diff --git a/dlib/any/any_trainer_abstract.h b/dlib/any/any_trainer_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..30f013c00eb8d145d7e058b4698972ace6dc903e --- /dev/null +++ b/dlib/any/any_trainer_abstract.h @@ -0,0 +1,243 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_AnY_TRAINER_ABSTRACT_H_ +#ifdef DLIB_AnY_TRAINER_ABSTRACT_H_ + +#include "any_abstract.h" +#include "../algs.h" +#include "any_decision_function_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type_, + typename scalar_type_ = double + > + class any_trainer + { + /*! + INITIAL VALUE + - is_empty() == true + - for all T: contains() == false + + WHAT THIS OBJECT REPRESENTS + This object is a version of dlib::any that is restricted to containing + elements which are some kind of object with a .train() method compatible + with the following signature: + + decision_function train( + const std::vector& samples, + const std::vector& labels + ) const + + Where decision_function is a type capable of being stored in an + any_decision_function object. + + any_trainer is intended to be used to contain objects such as the svm_nu_trainer + and other similar types which represent supervised machine learning algorithms. + It allows you to write code which contains and processes these trainer objects + without needing to know the specific types of trainer objects used. + !*/ + + public: + + typedef sample_type_ sample_type; + typedef scalar_type_ scalar_type; + typedef default_memory_manager mem_manager_type; + typedef any_decision_function trained_function_type; + + any_trainer( + ); + /*! + ensures + - this object is properly initialized + !*/ + + any_trainer ( + const any_trainer& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + any_trainer ( + any_trainer&& item + ); + /*! + ensures + - #item.is_empty() == true + - moves item into *this. + !*/ + + template < typename T > + any_trainer ( + const T& item + ); + /*! + ensures + - #contains() == true + - #cast_to() == item + (i.e. a copy of item will be stored in *this) + !*/ + + void clear ( + ); + /*! + ensures + - #*this will have its default value. I.e. #is_empty() == true + !*/ + + template + bool contains ( + ) const; + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + + bool is_empty( + ) const; + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + + trained_function_type train ( + const std::vector& samples, + const std::vector& labels + ) const + /*! + requires + - is_empty() == false + ensures + - Let TRAINER denote the object contained within *this. Then + this function performs: + return TRAINER.train(samples, labels) + !*/ + + template + T& cast_to( + ); + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + const T& cast_to( + ) const; + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + + template + T& get( + ); + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any_trainer object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + + any_trainer& operator= ( + const any_trainer& item + ); + /*! + ensures + - copies the state of item into *this. + - Note that *this and item will contain independent copies of the + contents of item. That is, this function performs a deep + copy and therefore does not result in *this containing + any kind of reference to item. + !*/ + + void swap ( + any_trainer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename sample_type, + typename scalar_type + > + inline void swap ( + any_trainer& a, + any_trainer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename scalar_type + > + T& any_cast( + any_trainer& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename sample_type, + typename scalar_type + > + const T& any_cast( + const any_trainer& a + ) { return a.cast_to(); } + /*! + ensures + - returns a.cast_to() + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_AnY_TRAINER_ABSTRACT_H_ + + diff --git a/dlib/any/storage.h b/dlib/any/storage.h new file mode 100644 index 0000000000000000000000000000000000000000..6e3a073ca6494c6c62e5153efcc7d2c1a2c86e79 --- /dev/null +++ b/dlib/any/storage.h @@ -0,0 +1,966 @@ +// Copyright (C) 2022 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_TYPE_ERASURE_H_ +#define DLIB_TYPE_ERASURE_H_ + +#include +#include +#include +#include +#include + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------------------- + + class bad_any_cast : public std::bad_cast + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is the exception class used by the storage objects. + It is used to indicate when someone attempts to cast a storage + object into a type which isn't contained in the object. + !*/ + + public: + virtual const char * what() const throw() + { + return "bad_any_cast"; + } + }; + +// ----------------------------------------------------------------------------------------------------- + + namespace te + { + + /*! + This is used as a SFINAE tool to prevent a function taking a universal reference from + binding to some undesired type. For example: + template < + typename T, + T_is_not_this_type = true + > + void foo(T&&); + prevents foo() from binding to an object of type SomeExcludedType. + !*/ + template + using T_is_not_this_type = std::enable_if_t, Storage>::value, bool>; + +// ----------------------------------------------------------------------------------------------------- + + template + class storage_base + { + /*! + WHAT THIS OBJECT REPRESENTS + This class defines functionality common to all type erasure storage objects + (defined below in this file). These objects are essentially type-safe versions of + a void*. In particular, they are containers which can contain only one object + but the object may be of any type. + + Each storage object implements a different way of storing the underlying object. + E.g. on the heap or stack or some other more specialized method. + !*/ + + public: + + bool is_empty() const + /*! + ensures + - if (this object contains any kind of object) then + - returns false + - else + - returns true + !*/ + { + const Storage& me = *static_cast(this); + return me.get_ptr() == nullptr; + } + + template + bool contains() const + /*! + ensures + - if (this object currently contains an object of type T) then + - returns true + - else + - returns false + !*/ + { + const Storage& me = *static_cast(this); + return !is_empty() && me.type_id() == std::type_index{typeid(T)}; + } + + template + T& unsafe_get() + /*! + requires + - contains() == true + ensures + - returns a reference to the object contained within *this. + !*/ + { + Storage& me = *static_cast(this); + return *reinterpret_cast(me.get_ptr()); + } + + template + const T& unsafe_get() const + /*! + requires + - contains() == true + ensures + - returns a const reference to the object contained within *this. + !*/ + { + const Storage& me = *static_cast(this); + return *reinterpret_cast(me.get_ptr()); + } + + template + T& get( + ) + /*! + ensures + - #is_empty() == false + - #contains() == true + - if (contains() == true) + - returns a non-const reference to the object contained in *this. + - else + - Constructs an object of type T inside *this + - Any previous object stored in this any object is destructed and its + state is lost. + - returns a non-const reference to the newly created T object. + !*/ + { + Storage& me = *static_cast(this); + + if (!contains()) + me = T{}; + return unsafe_get(); + } + + template + T& cast_to( + ) + /*! + ensures + - if (contains() == true) then + - returns a non-const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + { + if (!contains()) + throw bad_any_cast{}; + return unsafe_get(); + } + + template + const T& cast_to( + ) const + /*! + ensures + - if (contains() == true) then + - returns a const reference to the object contained within *this + - else + - throws bad_any_cast + !*/ + { + if (!contains()) + throw bad_any_cast{}; + return unsafe_get(); + } + }; + +// ----------------------------------------------------------------------------------------------------- + + class storage_heap : public storage_base + { + public: + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses heap allocation only. + !*/ + + storage_heap() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_heap(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + - #unsafe_get() will yield the provided t. + !*/ + : ptr{new T_{std::forward(t)}}, + del{[](void *self) { + delete reinterpret_cast(self); + }}, + copy{[](const void *self) -> void * { + return new T_{*reinterpret_cast(self)}; + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + storage_heap(const storage_heap& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor. + !*/ + : ptr{other.ptr ? other.copy(other.ptr) : nullptr}, + del{other.del}, + copy{other.copy}, + type_id_{other.type_id_} + { + } + + storage_heap& operator=(const storage_heap& other) + /*! + ensures + - if is_empty() == false then + - destructs the object contained in this class. + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor. + !*/ + { + if (this != &other) + *this = std::move(storage_heap{other}); + return *this; + } + + storage_heap(storage_heap&& other) noexcept + /*! + ensures + - The state of other is moved into *this. + - #other.is_empty() == true + !*/ + : ptr{std::exchange(other.ptr, nullptr)}, + del{std::exchange(other.del, nullptr)}, + copy{std::exchange(other.copy, nullptr)}, + type_id_{std::exchange(other.type_id_, nullptr)} + { + } + + storage_heap& operator=(storage_heap&& other) noexcept + /*! + ensures + - The state of other is moved into *this. + - #other.is_empty() == true + - returns *this + !*/ + { + if (this != &other) + { + clear(); + ptr = std::exchange(other.ptr, nullptr); + del = std::exchange(other.del, nullptr); + copy = std::exchange(other.copy, nullptr); + type_id_ = std::exchange(other.type_id_, nullptr); + } + return *this; + } + + ~storage_heap() + /*! + ensures + - destructs the object contained in *this if one exists. + !*/ + { + if (ptr) + del(ptr); + } + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + storage_heap{std::move(*this)}; + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + void* ptr = nullptr; + void (*del)(void*) = nullptr; + void* (*copy)(const void*) = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + template + class storage_stack : public storage_base> + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses stack allocation using a template size and alignment. + Therefore, only objects whose size and alignment fits the template parameters can be + erased and absorbed into this object. Attempting to store a type not + representable on the stack with those settings will result in a build error. + !*/ + + public: + storage_stack() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_stack(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + !*/ + : del{[](storage_stack& self) { + reinterpret_cast(&self.data)->~T_(); + self.del = nullptr; + self.copy = nullptr; + self.move = nullptr; + self.type_id_ = nullptr; + }}, + copy{[](const storage_stack& src, storage_stack& dst) { + new (&dst.data) T_{*reinterpret_cast(&src.data)}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + move{[](storage_stack& src, storage_stack& dst) { + new (&dst.data) T_{std::move(*reinterpret_cast(&src.data))}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + static_assert(sizeof(T_) <= Size, "insufficient size"); + static_assert(Alignment % alignof(T_) == 0, "bad alignment"); + new (&data) T_{std::forward(t)}; + } + + storage_stack(const storage_stack& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor. + !*/ + { + if (other.copy) + other.copy(other, *this); + } + + storage_stack& operator=(const storage_stack& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if is_empty() == false then + - destructs the object contained in this class. + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor + !*/ + { + if (this != &other) + { + clear(); + if (other.copy) + other.copy(other, *this); + } + return *this; + } + + storage_stack(storage_stack&& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is moved using erased type's moved constructor + !*/ + { + if (other.move) + other.move(other, *this); + } + + storage_stack& operator=(storage_stack&& other) + /*! + ensures + - if is_empty() == false then + - destructs the object contained in this class. + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is moved using erased type's moved constructor. + This does not make other empty. It will still contain a moved from object + of the underlying type in whatever that object's moved from state is. + - #other.is_empty() == false + !*/ + { + if (this != &other) + { + clear(); + if (other.move) + other.move(other, *this); + } + return *this; + } + + ~storage_stack() + /*! + ensures + - destructs the object contained in *this if one exists. + !*/ + { + clear(); + } + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + if (del) + del(*this); + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return del ? (void*)&data : nullptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return del ? (const void*)&data : nullptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + std::aligned_storage_t data; + void (*del)(storage_stack&) = nullptr; + void (*copy)(const storage_stack&, storage_stack&) = nullptr; + void (*move)(storage_stack&, storage_stack&) = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + template + class storage_sbo : public storage_base> + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses small buffer optimization (SBO), i.e. optional + stack allocation if the erased type has sizeof <= Size and alignment + requirements no greater than the given Alignment template value. If not it + allocates the object on the heap. + !*/ + + public: + // type_fits::value tells us if our SBO can hold T. + template + struct type_fits : std::integral_constant{}; + + storage_sbo() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true, + std::enable_if_t::value, bool> = true + > + storage_sbo(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + - stack allocation is used + !*/ + : ptr{new (&data) T_{std::forward(t)}}, + del{[](storage_sbo& self) { + reinterpret_cast(&self.data)->~T_(); + self.ptr = nullptr; + self.del = nullptr; + self.copy = nullptr; + self.move = nullptr; + self.type_id_ = nullptr; + }}, + copy{[](const storage_sbo& src, storage_sbo& dst) { + dst.ptr = new (&dst.data) T_{*reinterpret_cast(src.ptr)}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + move{[](storage_sbo& src, storage_sbo& dst) { + dst.ptr = new (&dst.data) T_{std::move(*reinterpret_cast(src.ptr))}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true, + std::enable_if_t::value, bool> = true + > + storage_sbo(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == false + - #contains>() == true + - heap allocation is used + !*/ + : ptr{new T_{std::forward(t)}}, + del{[](storage_sbo& self) { + delete reinterpret_cast(self.ptr); + self.ptr = nullptr; + self.del = nullptr; + self.copy = nullptr; + self.move = nullptr; + self.type_id_ = nullptr; + }}, + copy{[](const storage_sbo& src, storage_sbo& dst) { + dst.ptr = new T_{*reinterpret_cast(src.ptr)}; + dst.del = src.del; + dst.copy = src.copy; + dst.move = src.move; + dst.type_id_ = src.type_id_; + }}, + move{[](storage_sbo& src, storage_sbo& dst) { + dst.ptr = std::exchange(src.ptr, nullptr); + dst.del = std::exchange(src.del, nullptr); + dst.copy = std::exchange(src.copy, nullptr); + dst.move = std::exchange(src.move, nullptr); + dst.type_id_ = std::exchange(src.type_id_, nullptr); + }}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + storage_sbo(const storage_sbo& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor + !*/ + { + if (other.copy) + other.copy(other, *this); + } + + storage_sbo& operator=(const storage_sbo& other) + /*! + ensures + - if is_empty() == false then + - destructs the object contained in this class. + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - underlying object of other is copied using erased type's copy constructor + !*/ + { + if (this != &other) + { + clear(); + if (other.copy) + other.copy(other, *this); + } + return *this; + } + + storage_sbo(storage_sbo&& other) + /*! + ensures + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - if underlying object of other is allocated on stack then + - underlying object of other is moved using erased type's moved constructor. + This does not make other empty. It will still contain a moved from + object of the underlying type in whatever that object's moved from + state is. + - #other.is_empty() == false + - else + - storage heap pointer is moved. + - #other.is_empty() == true + !*/ + { + if (other.move) + other.move(other, *this); + } + + storage_sbo& operator=(storage_sbo&& other) + /*! + ensures + - underlying object is destructed if is_empty() == false + - #is_empty() == other.is_empty() + - if other.is_empty() == false then + - if underlying object of other is allocated on stack then + - underlying object of other is moved using erased type's moved constructor. + This does not make other empty. It will still contain a moved from + object of the underlying type in whatever that object's moved from + state is. + - #other.is_empty() == false + - else + - storage heap pointer is moved. + - #other.is_empty() == true + !*/ + { + if (this != &other) + { + clear(); + if (other.move) + other.move(other, *this); + } + return *this; + } + + ~storage_sbo() + /*! + ensures + - destructs the object contained in *this if one exists. + !*/ + { + clear(); + } + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + if (ptr) + del(*this); + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + std::aligned_storage_t data; + void* ptr = nullptr; + void (*del)(storage_sbo&) = nullptr; + void (*copy)(const storage_sbo&, storage_sbo&) = nullptr; + void (*move)(storage_sbo&, storage_sbo&) = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + class storage_shared : public storage_base + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type uses std::shared_ptr to store and erase + incoming objects. Therefore, it uses heap allocation and reference counting. + Moreover, it has the same copying and move semantics as std::shared_ptr. I.e. + it results in the underlying object being held by reference rather than by + value. + !*/ + + public: + storage_shared() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_shared(T &&t) noexcept(std::is_nothrow_constructible::value) + /*! + ensures + - copies or moves the incoming object (depending on the forwarding reference) + - #is_empty() == true + - #contains>() == true + !*/ + : ptr{std::make_shared(std::forward(t))}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + // This object has the same copy/move semantics as a std::shared_ptr + storage_shared(const storage_shared& other) = default; + storage_shared& operator=(const storage_shared& other) = default; + storage_shared(storage_shared&& other) noexcept = default; + storage_shared& operator=(storage_shared&& other) noexcept = default; + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + ptr = nullptr; + type_id_ = nullptr; + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr.get(); + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr.get(); + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + std::shared_ptr ptr = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + class storage_view : public storage_base + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a storage type that uses type erasure to erase any type. + + This particular storage type is a view type, similar to std::string_view or + std::span. So underlying objects are only ever referenced, not copied, moved or + destructed. That is, instances of this object take no ownership of the objects + they contain. So they are only valid as long as the contained object exists. + So storage_view merely holds a pointer to the underlying object. + !*/ + + public: + storage_view() = default; + /*! + ensures + - #is_empty() == true + - for all T: #contains() == false + !*/ + + template < + class T, + class T_ = std::decay_t, + T_is_not_this_type = true + > + storage_view(T &&t) noexcept + /*! + ensures + - #get_ptr() == &t + - #is_empty() == false + - #contains>() == true + !*/ + : ptr{&t}, + type_id_{[] { + return std::type_index{typeid(T_)}; + }} + { + } + + // This object has the same copy/move semantics as a void*. + storage_view(const storage_view& other) = default; + storage_view& operator=(const storage_view& other) = default; + storage_view(storage_view&& other) noexcept = default; + storage_view& operator=(storage_view&& other) noexcept = default; + + void clear() + /*! + ensures + - #is_empty() == true + !*/ + { + ptr = nullptr; + type_id_ = nullptr; + } + + void* get_ptr() + /*! + ensures + - returns a pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + const void* get_ptr() const + /*! + ensures + - returns a const pointer to the underlying object or nullptr if is_empty() + !*/ + { + return ptr; + } + + std::type_index type_id() const + /*! + requires + - is_empty() == false + ensures + - returns the std::type_index of the type contained within this object. + I.e. if this object contains the type T then this returns std::type_index{typeid(T)}. + !*/ + { + return type_id_(); + } + + private: + void* ptr = nullptr; + std::type_index (*type_id_)() = nullptr; + }; + +// ----------------------------------------------------------------------------------------------------- + + } +} + +#endif //DLIB_TYPE_ERASURE_H_ diff --git a/dlib/array.h b/dlib/array.h new file mode 100644 index 0000000000000000000000000000000000000000..ecdafc497952fdb4865886a46b4978d1eaebd3ea --- /dev/null +++ b/dlib/array.h @@ -0,0 +1,10 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAy_ +#define DLIB_ARRAy_ + +#include "array/array_kernel.h" +#include "array/array_tools.h" + +#endif // DLIB_ARRAy_ + diff --git a/dlib/array/array_kernel.h b/dlib/array/array_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..0d18d1f45f4c4902fb7f2832cd154779ba2cee8f --- /dev/null +++ b/dlib/array/array_kernel.h @@ -0,0 +1,809 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY_KERNEl_2_ +#define DLIB_ARRAY_KERNEl_2_ + +#include "array_kernel_abstract.h" +#include "../interfaces/enumerable.h" +#include "../algs.h" +#include "../serialize.h" +#include "../sort.h" +#include "../is_kind.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array : public enumerable + { + + /*! + INITIAL VALUE + - array_size == 0 + - max_array_size == 0 + - array_elements == 0 + - pos == 0 + - last_pos == 0 + - _at_start == true + + CONVENTION + - array_size == size() + - max_array_size == max_size() + - if (max_array_size > 0) + - array_elements == pointer to max_array_size elements of type T + - else + - array_elements == 0 + + - if (array_size > 0) + - last_pos == array_elements + array_size - 1 + - else + - last_pos == 0 + + + - at_start() == _at_start + - current_element_valid() == pos != 0 + - if (current_element_valid()) then + - *pos == element() + !*/ + + public: + + // These typedefs are here for backwards compatibility with old versions of dlib. + typedef array kernel_1a; + typedef array kernel_1a_c; + typedef array kernel_2a; + typedef array kernel_2a_c; + typedef array sort_1a; + typedef array sort_1a_c; + typedef array sort_1b; + typedef array sort_1b_c; + typedef array sort_2a; + typedef array sort_2a_c; + typedef array sort_2b; + typedef array sort_2b_c; + typedef array expand_1a; + typedef array expand_1a_c; + typedef array expand_1b; + typedef array expand_1b_c; + typedef array expand_1c; + typedef array expand_1c_c; + typedef array expand_1d; + typedef array expand_1d_c; + + + + + typedef T type; + typedef T value_type; + typedef mem_manager mem_manager_type; + + array ( + ) : + array_size(0), + max_array_size(0), + array_elements(0), + pos(0), + last_pos(0), + _at_start(true) + {} + + array(const array&) = delete; + array& operator=(array&) = delete; + + array( + array&& item + ) : array() + { + swap(item); + } + + array& operator=( + array&& item + ) + { + swap(item); + return *this; + } + + explicit array ( + size_t new_size + ) : + array_size(0), + max_array_size(0), + array_elements(0), + pos(0), + last_pos(0), + _at_start(true) + { + resize(new_size); + } + + ~array ( + ); + + void clear ( + ); + + inline const T& operator[] ( + size_t pos + ) const; + + inline T& operator[] ( + size_t pos + ); + + void set_size ( + size_t size + ); + + inline size_t max_size( + ) const; + + void set_max_size( + size_t max + ); + + void swap ( + array& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + inline bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + inline const T& element ( + ) const; + + inline T& element ( + ); + + bool move_next ( + ) const; + + void sort ( + ); + + void resize ( + size_t new_size + ); + + const T& back ( + ) const; + + T& back ( + ); + + void pop_back ( + ); + + void pop_back ( + T& item + ); + + void push_back ( + T& item + ); + + void push_back ( + T&& item + ); + + typedef T* iterator; + typedef const T* const_iterator; + iterator begin() { return array_elements; } + const_iterator begin() const { return array_elements; } + iterator end() { return array_elements+array_size; } + const_iterator end() const { return array_elements+array_size; } + + private: + + typename mem_manager::template rebind::other pool; + + // data members + size_t array_size; + size_t max_array_size; + T* array_elements; + + mutable T* pos; + T* last_pos; + mutable bool _at_start; + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + array& a, + array& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void serialize ( + const array& item, + std::ostream& out + ) + { + try + { + serialize(item.max_size(),out); + serialize(item.size(),out); + + for (size_t i = 0; i < item.size(); ++i) + serialize(item[i],out); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array"); + } + } + + template < + typename T, + typename mem_manager + > + void deserialize ( + array& item, + std::istream& in + ) + { + try + { + size_t max_size, size; + deserialize(max_size,in); + deserialize(size,in); + item.set_max_size(max_size); + item.set_size(size); + for (size_t i = 0; i < size; ++i) + deserialize(item[i],in); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + array:: + ~array ( + ) + { + if (array_elements) + { + pool.deallocate_array(array_elements); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + clear ( + ) + { + reset(); + last_pos = 0; + array_size = 0; + if (array_elements) + { + pool.deallocate_array(array_elements); + } + array_elements = 0; + max_array_size = 0; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + operator[] ( + size_t pos + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( pos < this->size() , + "\tconst T& array::operator[]" + << "\n\tpos must < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return array_elements[pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + operator[] ( + size_t pos + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( pos < this->size() , + "\tT& array::operator[]" + << "\n\tpos must be < size()" + << "\n\tpos: " << pos + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return array_elements[pos]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + set_size ( + size_t size + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( size <= this->max_size() ), + "\tvoid array::set_size" + << "\n\tsize must be <= max_size()" + << "\n\tsize: " << size + << "\n\tmax size: " << this->max_size() + << "\n\tthis: " << this + ); + + reset(); + array_size = size; + if (size > 0) + last_pos = array_elements + size - 1; + else + last_pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t array:: + size ( + ) const + { + return array_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + set_max_size( + size_t max + ) + { + reset(); + array_size = 0; + last_pos = 0; + if (max != 0) + { + // if new max size is different + if (max != max_array_size) + { + if (array_elements) + { + pool.deallocate_array(array_elements); + } + // try to get more memroy + try { array_elements = pool.allocate_array(max); } + catch (...) { array_elements = 0; max_array_size = 0; throw; } + max_array_size = max; + } + + } + // if the array is being made to be zero + else + { + if (array_elements) + pool.deallocate_array(array_elements); + max_array_size = 0; + array_elements = 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + size_t array:: + max_size ( + ) const + { + return max_array_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + swap ( + array& item + ) + { + auto array_size_temp = item.array_size; + auto max_array_size_temp = item.max_array_size; + T* array_elements_temp = item.array_elements; + + item.array_size = array_size; + item.max_array_size = max_array_size; + item.array_elements = array_elements; + + array_size = array_size_temp; + max_array_size = max_array_size_temp; + array_elements = array_elements_temp; + + exchange(_at_start,item._at_start); + exchange(pos,item.pos); + exchange(last_pos,item.last_pos); + pool.swap(item.pool); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + at_start ( + ) const + { + return _at_start; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + reset ( + ) const + { + _at_start = true; + pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + current_element_valid ( + ) const + { + return pos != 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(this->current_element_valid(), + "\tconst T& array::element()" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + return *pos; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(this->current_element_valid(), + "\tT& array::element()" + << "\n\tThe current element must be valid if you are to access it." + << "\n\tthis: " << this + ); + + return *pos; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + bool array:: + move_next ( + ) const + { + if (!_at_start) + { + if (pos < last_pos) + { + ++pos; + return true; + } + else + { + pos = 0; + return false; + } + } + else + { + _at_start = false; + if (array_size > 0) + { + pos = array_elements; + return true; + } + else + { + return false; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// Yet more functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + sort ( + ) + { + if (this->size() > 1) + { + // call the quick sort function for arrays that is in algs.h + dlib::qsort_array(*this,0,this->size()-1); + } + this->reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + resize ( + size_t new_size + ) + { + if (this->max_size() < new_size) + { + array temp; + temp.set_max_size(new_size); + temp.set_size(new_size); + for (size_t i = 0; i < this->size(); ++i) + { + exchange((*this)[i],temp[i]); + } + temp.swap(*this); + } + else + { + this->set_size(new_size); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + T& array:: + back ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tT& array::back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return (*this)[this->size()-1]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + const T& array:: + back ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tconst T& array::back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + return (*this)[this->size()-1]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + pop_back ( + T& item + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tvoid array::pop_back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + exchange(item,(*this)[this->size()-1]); + this->set_size(this->size()-1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + pop_back ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( this->size() > 0 , + "\tvoid array::pop_back()" + << "\n\tsize() must be bigger than 0" + << "\n\tsize(): " << this->size() + << "\n\tthis: " << this + ); + + this->set_size(this->size()-1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + push_back ( + T& item + ) + { + if (this->max_size() == this->size()) + { + // double the size of the array + array temp; + temp.set_max_size(this->size()*2 + 1); + temp.set_size(this->size()+1); + for (size_t i = 0; i < this->size(); ++i) + { + exchange((*this)[i],temp[i]); + } + exchange(item,temp[temp.size()-1]); + temp.swap(*this); + } + else + { + this->set_size(this->size()+1); + exchange(item,(*this)[this->size()-1]); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array:: + push_back ( + T&& item + ) { push_back(item); } + +// ---------------------------------------------------------------------------------------- + + template + struct is_array > + { + const static bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY_KERNEl_2_ + diff --git a/dlib/array/array_kernel_abstract.h b/dlib/array/array_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..5cfdd483ac9da2e4b27916d05c5403f87dc6977b --- /dev/null +++ b/dlib/array/array_kernel_abstract.h @@ -0,0 +1,360 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY_KERNEl_ABSTRACT_ +#ifdef DLIB_ARRAY_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../algs.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array : public enumerable + { + + /*! + REQUIREMENTS ON T + T must have a default constructor. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + front(), back(), swap(), max_size(), set_size(), and operator[] + functions do not invalidate pointers or references to internal data. + All other functions have no such guarantee. + + INITIAL VALUE + size() == 0 + max_size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements of the array in the + order (*this)[0], (*this)[1], (*this)[2], ... + + WHAT THIS OBJECT REPRESENTS + This object represents an ordered 1-dimensional array of items, + each item is associated with an integer value. The items are + numbered from 0 though size() - 1 and the operator[] functions + run in constant time. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + !*/ + + public: + + typedef T type; + typedef T value_type; + typedef mem_manager mem_manager_type; + + array ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + explicit array ( + size_t new_size + ); + /*! + ensures + - #*this is properly initialized + - #size() == new_size + - #max_size() == new_size + - All elements of the array will have initial values for their type. + throws + - std::bad_alloc or any exception thrown by T's constructor + !*/ + + ~array ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + array( + array&& item + ); + /*! + ensures + - move constructs *this from item. Therefore, the state of item is + moved into *this and #item has a valid but unspecified state. + !*/ + + array& operator=( + array&& item + ); + /*! + ensures + - move assigns *this from item. Therefore, the state of item is + moved into *this and #item has a valid but unspecified state. + - returns a reference to #*this + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by T's constructor + if this exception is thrown then the array object is unusable + until clear() is called and succeeds + !*/ + + const T& operator[] ( + size_t pos + ) const; + /*! + requires + - pos < size() + ensures + - returns a const reference to the element at position pos + !*/ + + T& operator[] ( + size_t pos + ); + /*! + requires + - pos < size() + ensures + - returns a non-const reference to the element at position pos + !*/ + + void set_size ( + size_t size + ); + /*! + requires + - size <= max_size() + ensures + - #size() == size + - any element with index between 0 and size - 1 which was in the + array before the call to set_size() retains its value and index. + All other elements have undetermined (but valid for their type) + values. (e.g. this object might buffer old T objects and reuse + them without reinitializing them between calls to set_size()) + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + may throw this exception if there is not enough memory and + if it does throw then the call to set_size() has no effect + !*/ + + size_t max_size( + ) const; + /*! + ensures + - returns the maximum size of *this + !*/ + + void set_max_size( + size_t max + ); + /*! + ensures + - #max_size() == max + - #size() == 0 + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + may throw this exception if there is not enough + memory and if it does throw then max_size() == 0 + !*/ + + void swap ( + array& item + ); + /*! + ensures + - swaps *this and item + !*/ + + void sort ( + ); + /*! + requires + - T must be a type with that is comparable via operator< + ensures + - for all elements in #*this the ith element is <= the i+1 element + - #at_start() == true + throws + - std::bad_alloc or any exception thrown by T's constructor + data may be lost if sort() throws + !*/ + + void resize ( + size_t new_size + ); + /*! + ensures + - #size() == new_size + - #max_size() == max(new_size,max_size()) + - for all i < size() && i < new_size: + - #(*this)[i] == (*this)[i] + (i.e. All the original elements of *this which were at index + values less than new_size are unmodified.) + - for all valid i >= size(): + - #(*this)[i] has an undefined value + (i.e. any new elements of the array have an undefined value) + throws + - std::bad_alloc or any exception thrown by T's constructor. + If an exception is thrown then it has no effect on *this. + !*/ + + + const T& back ( + ) const; + /*! + requires + - size() != 0 + ensures + - returns a const reference to (*this)[size()-1] + !*/ + + T& back ( + ); + /*! + requires + - size() != 0 + ensures + - returns a non-const reference to (*this)[size()-1] + !*/ + + void pop_back ( + T& item + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - swaps (*this)[size()-1] into item + - All elements with an index less than size()-1 are + unmodified by this operation. + !*/ + + void pop_back ( + ); + /*! + requires + - size() != 0 + ensures + - #size() == size() - 1 + - All elements with an index less than size()-1 are + unmodified by this operation. + !*/ + + void push_back ( + T& item + ); + /*! + ensures + - #size() == size()+1 + - swaps item into (*this)[#size()-1] + - #back() == item + - #item has some undefined value (whatever happens to + get swapped out of the array) + throws + - std::bad_alloc or any exception thrown by T's constructor. + If an exception is thrown then it has no effect on *this. + !*/ + + void push_back (T&& item) { push_back(item); } + /*! + enable push_back from rvalues + !*/ + + typedef T* iterator; + typedef const T* const_iterator; + + iterator begin( + ); + /*! + ensures + - returns an iterator that points to the first element in this array or + end() if the array is empty. + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - returns a const iterator that points to the first element in this + array or end() if the array is empty. + !*/ + + iterator end( + ); + /*! + ensures + - returns an iterator that points to one past the end of the array. + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns a const iterator that points to one past the end of the + array. + !*/ + + private: + + // restricted functions + array(array&); // copy constructor + array& operator=(array&); // assignment operator + + }; + + template < + typename T + > + inline void swap ( + array& a, + array& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T + > + void serialize ( + const array& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + template < + typename T + > + void deserialize ( + array& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_ARRAY_KERNEl_ABSTRACT_ + diff --git a/dlib/array/array_tools.h b/dlib/array/array_tools.h new file mode 100644 index 0000000000000000000000000000000000000000..fce6343968da48774e9de4e3aada851026e1de8c --- /dev/null +++ b/dlib/array/array_tools.h @@ -0,0 +1,38 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY_tOOLS_H_ +#define DLIB_ARRAY_tOOLS_H_ + +#include "../assert.h" +#include "array_tools_abstract.h" + +namespace dlib +{ + template + void split_array ( + T& a, + T& b, + double frac + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0 <= frac && frac <= 1, + "\t void split_array()" + << "\n\t frac must be between 0 and 1." + << "\n\t frac: " << frac + ); + + const unsigned long asize = static_cast(a.size()*frac); + const unsigned long bsize = a.size()-asize; + + b.resize(bsize); + for (unsigned long i = 0; i < b.size(); ++i) + { + swap(b[i], a[i+asize]); + } + a.resize(asize); + } +} + +#endif // DLIB_ARRAY_tOOLS_H_ + diff --git a/dlib/array/array_tools_abstract.h b/dlib/array/array_tools_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..e9b95751891abaee53eeb871fe5255f21c301e31 --- /dev/null +++ b/dlib/array/array_tools_abstract.h @@ -0,0 +1,33 @@ +// Copyright (C) 2013 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY_tOOLS_ABSTRACT_H_ +#ifdef DLIB_ARRAY_tOOLS_ABSTRACT_H_ + +#include "array_kernel_abstract.h" + +namespace dlib +{ + template + void split_array ( + T& a, + T& b, + double frac + ); + /*! + requires + - 0 <= frac <= 1 + - T must be an array type such as dlib::array or std::vector + ensures + - This function takes the elements of a and splits them into two groups. The + first group remains in a and the second group is put into b. The ordering of + elements in a is preserved. In particular, concatenating #a with #b will + reproduce the original contents of a. + - The elements in a are moved around using global swap(). So they must be + swappable, but do not need to be copyable. + - #a.size() == floor(a.size()*frac) + - #b.size() == a.size()-#a.size() + !*/ +} + +#endif // DLIB_ARRAY_tOOLS_ABSTRACT_H_ + diff --git a/dlib/array2d.h b/dlib/array2d.h new file mode 100644 index 0000000000000000000000000000000000000000..f5325e4a2f2720e5d01c71cf77fea4ee18598bb8 --- /dev/null +++ b/dlib/array2d.h @@ -0,0 +1,12 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2d_ +#define DLIB_ARRAY2d_ + + +#include "array2d/array2d_kernel.h" +#include "array2d/serialize_pixel_overloads.h" +#include "array2d/array2d_generic_image.h" + +#endif // DLIB_ARRAY2d_ + diff --git a/dlib/array2d/array2d_generic_image.h b/dlib/array2d/array2d_generic_image.h new file mode 100644 index 0000000000000000000000000000000000000000..a96f5e3c2548046a195c0045f73ffffda20f8323 --- /dev/null +++ b/dlib/array2d/array2d_generic_image.h @@ -0,0 +1,67 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ +#define DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ + +#include "array2d_kernel.h" +#include "../image_processing/generic_image.h" + +namespace dlib +{ + template + struct image_traits > + { + typedef T pixel_type; + }; + template + struct image_traits > + { + typedef T pixel_type; + }; + + template + inline long num_rows( const array2d& img) { return img.nr(); } + template + inline long num_columns( const array2d& img) { return img.nc(); } + + template + inline void set_image_size( + array2d& img, + long rows, + long cols + ) { img.set_size(rows,cols); } + + template + inline void* image_data( + array2d& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline const void* image_data( + const array2d& img + ) + { + if (img.size() != 0) + return &img[0][0]; + else + return 0; + } + + template + inline long width_step( + const array2d& img + ) + { + return img.width_step(); + } + +} + +#endif // DLIB_ARRAY2D_GENERIC_iMAGE_Hh_ + diff --git a/dlib/array2d/array2d_kernel.h b/dlib/array2d/array2d_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..b4736ed18e66ab1e6a334d5bb3e99a240b372863 --- /dev/null +++ b/dlib/array2d/array2d_kernel.h @@ -0,0 +1,524 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_KERNEl_1_ +#define DLIB_ARRAY2D_KERNEl_1_ + +#include "array2d_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../geometry/rectangle.h" + +namespace dlib +{ + template < + typename T, + typename mem_manager = default_memory_manager + > + class array2d : public enumerable + { + + /*! + INITIAL VALUE + - nc_ == 0 + - nr_ == 0 + - data == 0 + - at_start_ == true + - cur == 0 + - last == 0 + + CONVENTION + - nc_ == nc() + - nr_ == nc() + - if (data != 0) then + - last == a pointer to the last element in the data array + - data == pointer to an array of nc_*nr_ T objects + - else + - nc_ == 0 + - nr_ == 0 + - data == 0 + - last == 0 + + + - nr_ * nc_ == size() + - if (cur == 0) then + - current_element_valid() == false + - else + - current_element_valid() == true + - *cur == element() + + - at_start_ == at_start() + !*/ + + + class row_helper; + public: + + // These typedefs are here for backwards compatibility with older versions of dlib. + typedef array2d kernel_1a; + typedef array2d kernel_1a_c; + + typedef T type; + typedef mem_manager mem_manager_type; + typedef T* iterator; + typedef const T* const_iterator; + + + // ----------------------------------- + + class row + { + /*! + CONVENTION + - nc_ == nc() + - for all x < nc_: + - (*this)[x] == data[x] + !*/ + + friend class array2d; + friend class row_helper; + + public: + long nc ( + ) const { return nc_; } + + const T& operator[] ( + long column + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(column < nc() && column >= 0, + "\tconst T& array2d::operator[](long column) const" + << "\n\tThe column index given must be less than the number of columns." + << "\n\tthis: " << this + << "\n\tcolumn: " << column + << "\n\tnc(): " << nc() + ); + + return data[column]; + } + + T& operator[] ( + long column + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(column < nc() && column >= 0, + "\tT& array2d::operator[](long column)" + << "\n\tThe column index given must be less than the number of columns." + << "\n\tthis: " << this + << "\n\tcolumn: " << column + << "\n\tnc(): " << nc() + ); + + return data[column]; + } + + private: + + row(T* data_, long cols) : data(data_), nc_(cols) {} + row(row&& r) = default; + row& operator=(row&& r) = default; + + T* data = nullptr; + long nc_ = 0; + + + // restricted functions + row(const row&) = delete; + row& operator=(const row&) = delete; + }; + + // ----------------------------------- + + array2d ( + ) : + data(0), + nc_(0), + nr_(0), + cur(0), + last(0), + at_start_(true) + { + } + + array2d( + long rows, + long cols + ) : + data(0), + nc_(0), + nr_(0), + cur(0), + last(0), + at_start_(true) + { + // make sure requires clause is not broken + DLIB_ASSERT((cols >= 0 && rows >= 0), + "\t array2d::array2d(long rows, long cols)" + << "\n\t The array2d can't have negative rows or columns." + << "\n\t this: " << this + << "\n\t cols: " << cols + << "\n\t rows: " << rows + ); + + set_size(rows,cols); + } + + array2d(const array2d&) = delete; // copy constructor + array2d& operator=(const array2d&) = delete; // assignment operator + +#ifdef DLIB_HAS_RVALUE_REFERENCES + array2d(array2d&& item) : array2d() + { + swap(item); + } + + array2d& operator= ( + array2d&& rhs + ) + { + swap(rhs); + return *this; + } +#endif + + virtual ~array2d ( + ) { clear(); } + + long nc ( + ) const { return nc_; } + + long nr ( + ) const { return nr_; } + + row operator[] ( + long row_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(row_ < nr() && row_ >= 0, + "\trow array2d::operator[](long row_)" + << "\n\tThe row index given must be less than the number of rows." + << "\n\tthis: " << this + << "\n\trow_: " << row_ + << "\n\tnr(): " << nr() + ); + + return row(data+row_*nc_, nc_); + } + + const row operator[] ( + long row_ + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(row_ < nr() && row_ >= 0, + "\tconst row array2d::operator[](long row_) const" + << "\n\tThe row index given must be less than the number of rows." + << "\n\tthis: " << this + << "\n\trow_: " << row_ + << "\n\tnr(): " << nr() + ); + + return row(data+row_*nc_, nc_); + } + + void swap ( + array2d& item + ) + { + exchange(data,item.data); + exchange(nr_,item.nr_); + exchange(nc_,item.nc_); + exchange(at_start_,item.at_start_); + exchange(cur,item.cur); + exchange(last,item.last); + pool.swap(item.pool); + } + + void clear ( + ) + { + if (data != 0) + { + pool.deallocate_array(data); + nc_ = 0; + nr_ = 0; + data = 0; + at_start_ = true; + cur = 0; + last = 0; + } + } + + void set_size ( + long rows, + long cols + ); + + bool at_start ( + ) const { return at_start_; } + + void reset ( + ) const { at_start_ = true; cur = 0; } + + bool current_element_valid ( + ) const { return (cur != 0); } + + const T& element ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst T& array2d::element()()" + << "\n\tYou can only call element() when you are at a valid one." + << "\n\tthis: " << this + ); + + return *cur; + } + + T& element ( + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tT& array2d::element()()" + << "\n\tYou can only call element() when you are at a valid one." + << "\n\tthis: " << this + ); + + return *cur; + } + + bool move_next ( + ) const + { + if (cur != 0) + { + if (cur != last) + { + ++cur; + return true; + } + cur = 0; + return false; + } + else if (at_start_) + { + cur = data; + at_start_ = false; + return (data != 0); + } + else + { + return false; + } + } + + size_t size ( + ) const { return static_cast(nc_) * static_cast(nr_); } + + long width_step ( + ) const + { + return nc_*sizeof(T); + } + + iterator begin() + { + return data; + } + + iterator end() + { + return data+size(); + } + + const_iterator begin() const + { + return data; + } + + const_iterator end() const + { + return data+size(); + } + + + private: + + + T* data; + long nc_; + long nr_; + + typename mem_manager::template rebind::other pool; + mutable T* cur; + T* last; + mutable bool at_start_; + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + inline void swap ( + array2d& a, + array2d& b + ) { a.swap(b); } + + + template < + typename T, + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + item.reset(); + while (item.move_next()) + serialize(item.element(),out); + item.reset(); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename T, + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + item.set_size(nr,nc); + + while (item.move_next()) + deserialize(item.element(),in); + item.reset(); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename T, + typename mem_manager + > + void array2d:: + set_size ( + long rows, + long cols + ) + { + // make sure requires clause is not broken + DLIB_ASSERT((cols >= 0 && rows >= 0) , + "\tvoid array2d::set_size(long rows, long cols)" + << "\n\tThe array2d can't have negative rows or columns." + << "\n\tthis: " << this + << "\n\tcols: " << cols + << "\n\trows: " << rows + ); + + // set the enumerator back at the start + at_start_ = true; + cur = 0; + + // don't do anything if we are already the right size. + if (nc_ == cols && nr_ == rows) + { + return; + } + + nc_ = cols; + nr_ = rows; + + // free any existing memory + if (data != 0) + { + pool.deallocate_array(data); + data = 0; + } + + // now setup this object to have the new size + try + { + if (nr_ > 0) + { + data = pool.allocate_array(nr_*nc_); + last = data + nr_*nc_ - 1; + } + } + catch (...) + { + if (data) + pool.deallocate_array(data); + + data = 0; + nc_ = 0; + nr_ = 0; + last = 0; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template + struct is_array2d > + { + const static bool value = true; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY2D_KERNEl_1_ + diff --git a/dlib/array2d/array2d_kernel_abstract.h b/dlib/array2d/array2d_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..daccfc600436f2a6f59ea58f37fc67537a3d94dd --- /dev/null +++ b/dlib/array2d/array2d_kernel_abstract.h @@ -0,0 +1,339 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_ARRAY2D_KERNEl_ABSTRACT_ +#ifdef DLIB_ARRAY2D_KERNEl_ABSTRACT_ + +#include "../interfaces/enumerable.h" +#include "../serialize.h" +#include "../algs.h" +#include "../geometry/rectangle_abstract.h" + +namespace dlib +{ + + template < + typename T, + typename mem_manager = default_memory_manager + > + class array2d : public enumerable + { + + /*! + REQUIREMENTS ON T + T must have a default constructor. + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + POINTERS AND REFERENCES TO INTERNAL DATA + No member functions in this object will invalidate pointers + or references to internal data except for the set_size() + and clear() member functions. + + INITIAL VALUE + nr() == 0 + nc() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the elements of the array starting + with row 0 and then proceeding to row 1 and so on. Each row will be + fully enumerated before proceeding on to the next row and the elements + in a row will be enumerated beginning with the 0th column, then the 1st + column and so on. + + WHAT THIS OBJECT REPRESENTS + This object represents a 2-Dimensional array of objects of + type T. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + + Finally, note that this object stores its data contiguously and in + row major order. Moreover, there is no padding at the end of each row. + This means that its width_step() value is always equal to sizeof(type)*nc(). + !*/ + + + public: + + // ---------------------------------------- + + typedef T type; + typedef mem_manager mem_manager_type; + typedef T* iterator; + typedef const T* const_iterator; + + // ---------------------------------------- + + class row + { + /*! + POINTERS AND REFERENCES TO INTERNAL DATA + No member functions in this object will invalidate pointers + or references to internal data. + + WHAT THIS OBJECT REPRESENTS + This object represents a row of Ts in an array2d object. + !*/ + public: + long nc ( + ) const; + /*! + ensures + - returns the number of columns in this row + !*/ + + const T& operator[] ( + long column + ) const; + /*! + requires + - 0 <= column < nc() + ensures + - returns a const reference to the T in the given column + !*/ + + T& operator[] ( + long column + ); + /*! + requires + - 0 <= column < nc() + ensures + - returns a non-const reference to the T in the given column + !*/ + + private: + // restricted functions + row(); + row& operator=(row&); + }; + + // ---------------------------------------- + + array2d ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + array2d(const array2d&) = delete; // copy constructor + array2d& operator=(const array2d&) = delete; // assignment operator + + array2d( + array2d&& item + ); + /*! + ensures + - Moves the state of item into *this. + - #item is in a valid but unspecified state. + !*/ + + array2d ( + long rows, + long cols + ); + /*! + requires + - rows >= 0 && cols >= 0 + ensures + - #nc() == cols + - #nr() == rows + - #at_start() == true + - all elements in this array have initial values for their type + throws + - std::bad_alloc + !*/ + + virtual ~array2d ( + ); + /*! + ensures + - all resources associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has an initial value for its type + !*/ + + long nc ( + ) const; + /*! + ensures + - returns the number of elements there are in a row. i.e. returns + the number of columns in *this + !*/ + + long nr ( + ) const; + /*! + ensures + - returns the number of rows in *this + !*/ + + void set_size ( + long rows, + long cols + ); + /*! + requires + - rows >= 0 && cols >= 0 + ensures + - #nc() == cols + - #nr() == rows + - #at_start() == true + - if (the call to set_size() doesn't change the dimensions of this array) then + - all elements in this array retain their values from before this function was called + - else + - all elements in this array have initial values for their type + throws + - std::bad_alloc + If this exception is thrown then #*this will have an initial + value for its type. + !*/ + + row operator[] ( + long row_index + ); + /*! + requires + - 0 <= row_index < nr() + ensures + - returns a non-const row of nc() elements that represents the + given row_index'th row in *this. + !*/ + + const row operator[] ( + long row_index + ) const; + /*! + requires + - 0 <= row_index < nr() + ensures + - returns a const row of nc() elements that represents the + given row_index'th row in *this. + !*/ + + void swap ( + array2d& item + ); + /*! + ensures + - swaps *this and item + !*/ + + array2d& operator= ( + array2d&& rhs + ); + /*! + ensures + - Moves the state of item into *this. + - #item is in a valid but unspecified state. + - returns #*this + !*/ + + long width_step ( + ) const; + /*! + ensures + - returns the size of one row of the image, in bytes. + More precisely, return a number N such that: + (char*)&item[0][0] + N == (char*)&item[1][0]. + - for dlib::array2d objects, the returned value + is always equal to sizeof(type)*nc(). However, + other objects which implement dlib::array2d style + interfaces might have padding at the ends of their + rows and therefore might return larger numbers. + An example of such an object is the dlib::cv_image. + !*/ + + iterator begin( + ); + /*! + ensures + - returns a random access iterator pointing to the first element in this + object. + - The iterator will iterate over the elements of the object in row major + order. + !*/ + + iterator end( + ); + /*! + ensures + - returns a random access iterator pointing to one past the end of the last + element in this object. + !*/ + + const_iterator begin( + ) const; + /*! + ensures + - returns a random access iterator pointing to the first element in this + object. + - The iterator will iterate over the elements of the object in row major + order. + !*/ + + const_iterator end( + ) const; + /*! + ensures + - returns a random access iterator pointing to one past the end of the last + element in this object. + !*/ + + }; + + template < + typename T, + typename mem_manager + > + inline void swap ( + array2d& a, + array2d& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename T, + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ); + /*! + Provides serialization support. Note that the serialization formats used by the + dlib::matrix and dlib::array2d objects are compatible. That means you can load the + serialized data from one into another and it will work properly. + !*/ + + template < + typename T, + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +} + +#endif // DLIB_ARRAY2D_KERNEl_ABSTRACT_ + diff --git a/dlib/array2d/serialize_pixel_overloads.h b/dlib/array2d/serialize_pixel_overloads.h new file mode 100644 index 0000000000000000000000000000000000000000..91383a66654553b2356a677e848f2c772a72aa22 --- /dev/null +++ b/dlib/array2d/serialize_pixel_overloads.h @@ -0,0 +1,371 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ +#define DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ + +#include "array2d_kernel.h" +#include "../pixel.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + /* + This file contains overloads of the serialize functions for array2d object + for the case where they contain simple 8bit POD pixel types. In these + cases we can perform a much faster serialization by writing data in chunks + instead of one pixel at a time (this avoids a lot of function call overhead + inside the iostreams). + */ + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(rgb_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(rgb_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(bgr_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(bgr_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(hsi_pixel) == 3); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(hsi_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + COMPILE_TIME_ASSERT(sizeof(rgb_alpha_pixel) == 4); + + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(rgb_alpha_pixel)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mem_manager + > + void serialize ( + const array2d& item, + std::ostream& out + ) + { + try + { + // The reason the serialization is a little funny is because we are trying to + // maintain backwards compatibility with an older serialization format used by + // dlib while also encoding things in a way that lets the array2d and matrix + // objects have compatible serialization formats. + serialize(-item.nr(),out); + serialize(-item.nc(),out); + + if (item.size() != 0) + out.write((char*)&item[0][0], sizeof(unsigned char)*item.size()); + } + catch (serialization_error& e) + { + throw serialization_error(e.info + "\n while serializing object of type array2d"); + } + } + + template < + typename mem_manager + > + void deserialize ( + array2d& item, + std::istream& in + ) + { + try + { + long nr, nc; + deserialize(nr,in); + deserialize(nc,in); + // this is the newer serialization format + if (nr < 0 || nc < 0) + { + nr *= -1; + nc *= -1; + } + else + { + std::swap(nr,nc); + } + + + item.set_size(nr,nc); + + if (item.size() != 0) + in.read((char*)&item[0][0], sizeof(unsigned char)*item.size()); + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type array2d"); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_ARRAY2D_SERIALIZE_PIXEL_OvERLOADS_Hh_ + diff --git a/dlib/assert.h b/dlib/assert.h new file mode 100644 index 0000000000000000000000000000000000000000..67dc634144130757bd61fa5e9668ff822c1f7e11 --- /dev/null +++ b/dlib/assert.h @@ -0,0 +1,217 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_ASSERt_ +#define DLIB_ASSERt_ + +#include "config.h" +#include +#include +#include "error.h" + +// ----------------------------- + +// Use some stuff from boost here +// (C) Copyright John Maddock 2001 - 2003. +// (C) Copyright Darin Adler 2001. +// (C) Copyright Peter Dimov 2001. +// (C) Copyright Bill Kempf 2002. +// (C) Copyright Jens Maurer 2002. +// (C) Copyright David Abrahams 2002 - 2003. +// (C) Copyright Gennaro Prota 2003. +// (C) Copyright Eric Friedman 2003. +// License: Boost Software License See LICENSE.txt for the full license. +// +#ifndef DLIB_BOOST_JOIN +#define DLIB_BOOST_JOIN( X, Y ) DLIB_BOOST_DO_JOIN( X, Y ) +#define DLIB_BOOST_DO_JOIN( X, Y ) DLIB_BOOST_DO_JOIN2(X,Y) +#define DLIB_BOOST_DO_JOIN2( X, Y ) X##Y +#endif + +// figure out if the compiler has rvalue references. +#if defined(__clang__) +# if __has_feature(cxx_rvalue_references) +# define DLIB_HAS_RVALUE_REFERENCES +# endif +# if __has_feature(cxx_generalized_initializers) +# define DLIB_HAS_INITIALIZER_LISTS +# endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) +# define DLIB_HAS_RVALUE_REFERENCES +# define DLIB_HAS_INITIALIZER_LISTS +#elif defined(_MSC_VER) && _MSC_VER >= 1800 +# define DLIB_HAS_INITIALIZER_LISTS +# define DLIB_HAS_RVALUE_REFERENCES +#elif defined(_MSC_VER) && _MSC_VER >= 1600 +# define DLIB_HAS_RVALUE_REFERENCES +#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) +# define DLIB_HAS_RVALUE_REFERENCES +# define DLIB_HAS_INITIALIZER_LISTS +#endif + +#if defined(__APPLE__) && defined(__GNUC_LIBSTD__) && ((__GNUC_LIBSTD__-0) * 100 + __GNUC_LIBSTD_MINOR__-0 <= 402) + // Apple has not updated libstdc++ in some time and anything under 4.02 does not have for sure. +# undef DLIB_HAS_INITIALIZER_LISTS +#endif + +// figure out if the compiler has static_assert. +#if defined(__clang__) +# if __has_feature(cxx_static_assert) +# define DLIB_HAS_STATIC_ASSERT +# endif +#elif defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ > 2)) && defined(__GXX_EXPERIMENTAL_CXX0X__) +# define DLIB_HAS_STATIC_ASSERT +#elif defined(_MSC_VER) && _MSC_VER >= 1600 +# define DLIB_HAS_STATIC_ASSERT +#elif defined(__INTEL_COMPILER) && defined(BOOST_INTEL_STDCXX0X) +# define DLIB_HAS_STATIC_ASSERT +#endif + + +// ----------------------------- + +namespace dlib +{ + template struct compile_time_assert; + template <> struct compile_time_assert { enum {value=1}; }; + + template struct assert_are_same_type; + template struct assert_are_same_type {enum{value=1};}; + template struct assert_are_not_same_type {enum{value=1}; }; + template struct assert_are_not_same_type {}; + + template struct assert_types_match {enum{value=0};}; + template struct assert_types_match {enum{value=1};}; +} + + +// gcc 4.8 will warn about unused typedefs. But we use typedefs in some of the compile +// time assert macros so we need to make it not complain about them "not being used". +#ifdef __GNUC__ +#define DLIB_NO_WARN_UNUSED __attribute__ ((unused)) +#else +#define DLIB_NO_WARN_UNUSED +#endif + +// Use the newer static_assert if it's available since it produces much more readable error +// messages. +#ifdef DLIB_HAS_STATIC_ASSERT + #define COMPILE_TIME_ASSERT(expression) static_assert(expression, "Failed assertion") + #define ASSERT_ARE_SAME_TYPE(type1, type2) static_assert(::dlib::assert_types_match::value, "These types should be the same but aren't.") + #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) static_assert(!::dlib::assert_types_match::value, "These types should NOT be the same.") +#else + #define COMPILE_TIME_ASSERT(expression) \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_CTA, __LINE__)[::dlib::compile_time_assert<(bool)(expression)>::value] + + #define ASSERT_ARE_SAME_TYPE(type1, type2) \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AAST, __LINE__)[::dlib::assert_are_same_type::value] + + #define ASSERT_ARE_NOT_SAME_TYPE(type1, type2) \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DLIB_AANST, __LINE__)[::dlib::assert_are_not_same_type::value] +#endif + +// ----------------------------- + +#if defined DLIB_DISABLE_ASSERTS + // if DLIB_DISABLE_ASSERTS is on then never enable DLIB_ASSERT no matter what. + #undef ENABLE_ASSERTS +#endif + +#if !defined(DLIB_DISABLE_ASSERTS) && ( defined DEBUG || defined _DEBUG) + // make sure ENABLE_ASSERTS is defined if we are indeed using them. + #ifndef ENABLE_ASSERTS + #define ENABLE_ASSERTS + #endif +#endif + +// ----------------------------- + +#ifdef __GNUC__ +// There is a bug in version 4.4.5 of GCC on Ubuntu which causes GCC to segfault +// when __PRETTY_FUNCTION__ is used within certain templated functions. So just +// don't use it with this version of GCC. +# if !(__GNUC__ == 4 && __GNUC_MINOR__ == 4 && __GNUC_PATCHLEVEL__ == 5) +# define DLIB_FUNCTION_NAME __PRETTY_FUNCTION__ +# else +# define DLIB_FUNCTION_NAME "unknown function" +# endif +#elif defined(_MSC_VER) +#define DLIB_FUNCTION_NAME __FUNCSIG__ +#else +#define DLIB_FUNCTION_NAME "unknown function" +#endif + +#define DLIBM_CASSERT(_exp,_message) \ + {if ( !(_exp) ) \ + { \ + dlib_assert_breakpoint(); \ + std::ostringstream dlib_o_out; \ + dlib_o_out << "\n\nError detected at line " << __LINE__ << ".\n"; \ + dlib_o_out << "Error detected in file " << __FILE__ << ".\n"; \ + dlib_o_out << "Error detected in function " << DLIB_FUNCTION_NAME << ".\n\n"; \ + dlib_o_out << "Failing expression was " << #_exp << ".\n"; \ + dlib_o_out << std::boolalpha << _message << "\n"; \ + throw dlib::fatal_error(dlib::EBROKEN_ASSERT,dlib_o_out.str()); \ + }} + +// This macro is not needed if you have a real C++ compiler. It's here to work around bugs in Visual Studio's preprocessor. +#define DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(x) x +// Make it so the 2nd argument of DLIB_CASSERT is optional. That is, you can call it like +// DLIB_CASSERT(exp) or DLIB_CASSERT(exp,message). +#define DLIBM_CASSERT_1_ARGS(exp) DLIBM_CASSERT(exp,"") +#define DLIBM_CASSERT_2_ARGS(exp,message) DLIBM_CASSERT(exp,message) +#define DLIBM_GET_3TH_ARG(arg1, arg2, arg3, ...) arg3 +#define DLIBM_CASSERT_CHOOSER(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_GET_3TH_ARG(__VA_ARGS__, DLIBM_CASSERT_2_ARGS, DLIBM_CASSERT_1_ARGS, DLIB_CASSERT_NEVER_USED)) +#define DLIB_CASSERT(...) DLIB_WORKAROUND_VISUAL_STUDIO_BUGS(DLIBM_CASSERT_CHOOSER(__VA_ARGS__)(__VA_ARGS__)) + + +#ifdef ENABLE_ASSERTS + #define DLIB_ASSERT(...) DLIB_CASSERT(__VA_ARGS__) + #define DLIB_IF_ASSERT(exp) exp +#else + #define DLIB_ASSERT(...) {} + #define DLIB_IF_ASSERT(exp) +#endif + +// ---------------------------------------------------------------------------------------- + + /*!A DLIB_ASSERT_HAS_STANDARD_LAYOUT + + This macro is meant to cause a compiler error if a type doesn't have a simple + memory layout (like a C struct). In particular, types with simple layouts are + ones which can be copied via memcpy(). + + + This was called a POD type in C++03 and in C++0x we are looking to check if + it is a "standard layout type". Once we can use C++0x we can change this macro + to something that uses the std::is_standard_layout type_traits class. + See: http://www2.research.att.com/~bs/C++0xFAQ.html#PODs + !*/ + // Use the fact that in C++03 you can't put non-PODs into a union. +#define DLIB_ASSERT_HAS_STANDARD_LAYOUT(type) \ + union DLIB_BOOST_JOIN(DAHSL_,__LINE__) { type TYPE_NOT_STANDARD_LAYOUT; }; \ + DLIB_NO_WARN_UNUSED typedef char DLIB_BOOST_JOIN(DAHSL2_,__LINE__)[sizeof(DLIB_BOOST_JOIN(DAHSL_,__LINE__))]; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +// breakpoints +extern "C" +{ + inline void dlib_assert_breakpoint( + ) {} + /*! + ensures + - this function does nothing + It exists just so you can put breakpoints on it in a debugging tool. + It is called only when an DLIB_ASSERT or DLIB_CASSERT fails and is about to + throw an exception. + !*/ +} + +// ----------------------------- + +#include "stack_trace.h" + +#endif // DLIB_ASSERt_ + diff --git a/dlib/base64.h b/dlib/base64.h new file mode 100644 index 0000000000000000000000000000000000000000..8308920d6e133a403a08160cde5bd5025408f2c9 --- /dev/null +++ b/dlib/base64.h @@ -0,0 +1,9 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASe64_ +#define DLIB_BASe64_ + +#include "base64/base64_kernel_1.h" + +#endif // DLIB_BASe64_ + diff --git a/dlib/base64/base64_kernel_1.cpp b/dlib/base64/base64_kernel_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5b48c789e8baa9e15681e02d35b28cdf64e711be --- /dev/null +++ b/dlib/base64/base64_kernel_1.cpp @@ -0,0 +1,403 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASE64_KERNEL_1_CPp_ +#define DLIB_BASE64_KERNEL_1_CPp_ + +#include "base64_kernel_1.h" +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + base64::line_ending_type base64:: + line_ending ( + ) const + { + return eol_style; + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + set_line_ending ( + line_ending_type eol_style_ + ) + { + eol_style = eol_style_; + } + +// ---------------------------------------------------------------------------------------- + + base64:: + base64 ( + ) : + encode_table(0), + decode_table(0), + bad_value(100), + eol_style(LF) + { + try + { + encode_table = new char[64]; + decode_table = new unsigned char[UCHAR_MAX]; + } + catch (...) + { + if (encode_table) delete [] encode_table; + if (decode_table) delete [] decode_table; + throw; + } + + // now set up the tables with the right stuff + encode_table[0] = 'A'; + encode_table[17] = 'R'; + encode_table[34] = 'i'; + encode_table[51] = 'z'; + + encode_table[1] = 'B'; + encode_table[18] = 'S'; + encode_table[35] = 'j'; + encode_table[52] = '0'; + + encode_table[2] = 'C'; + encode_table[19] = 'T'; + encode_table[36] = 'k'; + encode_table[53] = '1'; + + encode_table[3] = 'D'; + encode_table[20] = 'U'; + encode_table[37] = 'l'; + encode_table[54] = '2'; + + encode_table[4] = 'E'; + encode_table[21] = 'V'; + encode_table[38] = 'm'; + encode_table[55] = '3'; + + encode_table[5] = 'F'; + encode_table[22] = 'W'; + encode_table[39] = 'n'; + encode_table[56] = '4'; + + encode_table[6] = 'G'; + encode_table[23] = 'X'; + encode_table[40] = 'o'; + encode_table[57] = '5'; + + encode_table[7] = 'H'; + encode_table[24] = 'Y'; + encode_table[41] = 'p'; + encode_table[58] = '6'; + + encode_table[8] = 'I'; + encode_table[25] = 'Z'; + encode_table[42] = 'q'; + encode_table[59] = '7'; + + encode_table[9] = 'J'; + encode_table[26] = 'a'; + encode_table[43] = 'r'; + encode_table[60] = '8'; + + encode_table[10] = 'K'; + encode_table[27] = 'b'; + encode_table[44] = 's'; + encode_table[61] = '9'; + + encode_table[11] = 'L'; + encode_table[28] = 'c'; + encode_table[45] = 't'; + encode_table[62] = '+'; + + encode_table[12] = 'M'; + encode_table[29] = 'd'; + encode_table[46] = 'u'; + encode_table[63] = '/'; + + encode_table[13] = 'N'; + encode_table[30] = 'e'; + encode_table[47] = 'v'; + + encode_table[14] = 'O'; + encode_table[31] = 'f'; + encode_table[48] = 'w'; + + encode_table[15] = 'P'; + encode_table[32] = 'g'; + encode_table[49] = 'x'; + + encode_table[16] = 'Q'; + encode_table[33] = 'h'; + encode_table[50] = 'y'; + + + + // we can now fill out the decode_table by using the encode_table + for (int i = 0; i < UCHAR_MAX; ++i) + { + decode_table[i] = bad_value; + } + for (unsigned char i = 0; i < 64; ++i) + { + decode_table[(unsigned char)encode_table[i]] = i; + } + } + +// ---------------------------------------------------------------------------------------- + + base64:: + ~base64 ( + ) + { + delete [] encode_table; + delete [] decode_table; + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + encode ( + std::istream& in_, + std::ostream& out_ + ) const + { + using namespace std; + streambuf& in = *in_.rdbuf(); + streambuf& out = *out_.rdbuf(); + + unsigned char inbuf[3]; + unsigned char outbuf[4]; + streamsize status = in.sgetn(reinterpret_cast(&inbuf),3); + + unsigned char c1, c2, c3, c4, c5, c6; + + int counter = 19; + + // while we haven't hit the end of the input stream + while (status != 0) + { + if (counter == 0) + { + counter = 19; + // write a newline + char ch; + switch (eol_style) + { + case CR: + ch = '\r'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + case LF: + ch = '\n'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + case CRLF: + ch = '\r'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + ch = '\n'; + if (out.sputn(&ch,1)!=1) + throw std::ios_base::failure("error occurred in the base64 object"); + break; + default: + DLIB_CASSERT(false,"this should never happen"); + } + } + --counter; + + if (status == 3) + { + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = inbuf[1]&0xf0; + c4 = inbuf[1]&0x0f; + c5 = inbuf[2]&0xc0; + c6 = inbuf[2]&0x3f; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = (c4<<2)|(c5>>6); + outbuf[3] = c6; + + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + outbuf[2] = encode_table[outbuf[2]]; + outbuf[3] = encode_table[outbuf[3]]; + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + // get 3 more input bytes + status = in.sgetn(reinterpret_cast(&inbuf),3); + continue; + } + else if (status == 2) + { + // we are at the end of the input stream and need to add some padding + + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = inbuf[1]&0xf0; + c4 = inbuf[1]&0x0f; + c5 = 0; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = (c4<<2)|(c5>>6); + outbuf[3] = '='; + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + outbuf[2] = encode_table[outbuf[2]]; + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + + break; + } + else // in this case status must be 1 + { + // we are at the end of the input stream and need to add some padding + + // encode the bytes in inbuf to base64 and write them to the output stream + c1 = inbuf[0]&0xfc; + c2 = inbuf[0]&0x03; + c3 = 0; + + outbuf[0] = c1>>2; + outbuf[1] = (c2<<4)|(c3>>4); + outbuf[2] = '='; + outbuf[3] = '='; + + outbuf[0] = encode_table[outbuf[0]]; + outbuf[1] = encode_table[outbuf[1]]; + + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),4)!=4) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + + break; + } + } // while (status != 0) + + + // make sure the stream buffer flushes to its I/O channel + out.pubsync(); + } + +// ---------------------------------------------------------------------------------------- + + void base64:: + decode ( + std::istream& in_, + std::ostream& out_ + ) const + { + using namespace std; + streambuf& in = *in_.rdbuf(); + streambuf& out = *out_.rdbuf(); + + unsigned char inbuf[4]; + unsigned char outbuf[3]; + int inbuf_pos = 0; + streamsize status = in.sgetn(reinterpret_cast(inbuf),1); + + // only count this character if it isn't some kind of filler + if (status == 1 && decode_table[inbuf[0]] != bad_value ) + ++inbuf_pos; + + unsigned char c1, c2, c3, c4, c5, c6; + streamsize outsize; + + // while we haven't hit the end of the input stream + while (status != 0) + { + // if we have 4 valid characters + if (inbuf_pos == 4) + { + inbuf_pos = 0; + + // this might be the end of the encoded data so we need to figure out if + // there was any padding applied. + outsize = 3; + if (inbuf[3] == '=') + { + if (inbuf[2] == '=') + outsize = 1; + else + outsize = 2; + } + + // decode the incoming characters + inbuf[0] = decode_table[inbuf[0]]; + inbuf[1] = decode_table[inbuf[1]]; + inbuf[2] = decode_table[inbuf[2]]; + inbuf[3] = decode_table[inbuf[3]]; + + + // now pack these guys into bytes rather than 6 bit chunks + c1 = inbuf[0]<<2; + c2 = inbuf[1]>>4; + c3 = inbuf[1]<<4; + c4 = inbuf[2]>>2; + c5 = inbuf[2]<<6; + c6 = inbuf[3]; + + outbuf[0] = c1|c2; + outbuf[1] = c3|c4; + outbuf[2] = c5|c6; + + + // write the encoded bytes to the output stream + if (out.sputn(reinterpret_cast(&outbuf),outsize)!=outsize) + { + throw std::ios_base::failure("error occurred in the base64 object"); + } + } + + // get more input characters + status = in.sgetn(reinterpret_cast(inbuf + inbuf_pos),1); + // only count this character if it isn't some kind of filler + if ((decode_table[inbuf[inbuf_pos]] != bad_value || inbuf[inbuf_pos] == '=') && + status != 0) + ++inbuf_pos; + } // while (status != 0) + + if (inbuf_pos != 0) + { + ostringstream sout; + sout << inbuf_pos << " extra characters were found at the end of the encoded data." + << " This may indicate that the data stream has been truncated."; + // this happens if we hit EOF in the middle of decoding a 24bit block. + throw decode_error(sout.str()); + } + + // make sure the stream buffer flushes to its I/O channel + out.pubsync(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BASE64_KERNEL_1_CPp_ + diff --git a/dlib/base64/base64_kernel_1.h b/dlib/base64/base64_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..d8f49b1b8cad201499aead19ca759bec2fe8e919 --- /dev/null +++ b/dlib/base64/base64_kernel_1.h @@ -0,0 +1,92 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BASE64_KERNEl_1_ +#define DLIB_BASE64_KERNEl_1_ + +#include "../algs.h" +#include "base64_kernel_abstract.h" +#include + +namespace dlib +{ + + class base64 + { + /*! + INITIAL VALUE + - bad_value == 100 + - encode_table == a pointer to an array of 64 chars + - where x is a 6 bit value the following is true: + - encode_table[x] == the base64 encoding of x + - decode_table == a pointer to an array of UCHAR_MAX chars + - where x is any char value: + - if (x is a valid character in the base64 coding scheme) then + - decode_table[x] == the 6 bit value that x encodes + - else + - decode_table[x] == bad_value + + CONVENTION + - The state of this object never changes so just refer to its + initial value. + + + !*/ + + public: + // this is here for backwards compatibility with older versions of dlib. + typedef base64 kernel_1a; + + class decode_error : public dlib::error { public: + decode_error( const std::string& e) : error(e) {}}; + + base64 ( + ); + + virtual ~base64 ( + ); + + enum line_ending_type + { + CR, // i.e. "\r" + LF, // i.e. "\n" + CRLF // i.e. "\r\n" + }; + + line_ending_type line_ending ( + ) const; + + void set_line_ending ( + line_ending_type eol_style_ + ); + + void encode ( + std::istream& in, + std::ostream& out + ) const; + + void decode ( + std::istream& in, + std::ostream& out + ) const; + + private: + + char* encode_table; + unsigned char* decode_table; + const unsigned char bad_value; + line_ending_type eol_style; + + // restricted functions + base64(base64&); // copy constructor + base64& operator=(base64&); // assignment operator + + }; + +} + +#ifdef NO_MAKEFILE +#include "base64_kernel_1.cpp" +#endif + +#endif // DLIB_BASE64_KERNEl_1_ + diff --git a/dlib/base64/base64_kernel_abstract.h b/dlib/base64/base64_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..0a63d3b87be54f88dcfbf122850e81d9987034a2 --- /dev/null +++ b/dlib/base64/base64_kernel_abstract.h @@ -0,0 +1,121 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BASE64_KERNEl_ABSTRACT_ +#ifdef DLIB_BASE64_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include + +namespace dlib +{ + + class base64 + { + /*! + INITIAL VALUE + - line_ending() == LF + + WHAT THIS OBJECT REPRESENTS + This object consists of the two functions encode and decode. + These functions allow you to encode and decode data to and from + the Base64 Content-Transfer-Encoding defined in section 6.8 of + rfc2045. + !*/ + + public: + + class decode_error : public dlib::error {}; + + base64 ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~base64 ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + enum line_ending_type + { + CR, // i.e. "\r" + LF, // i.e. "\n" + CRLF // i.e. "\r\n" + }; + + line_ending_type line_ending ( + ) const; + /*! + ensures + - returns the type of end of line bytes the encoder + will use when encoding data to base64 blocks. Note that + the ostream object you use might apply some sort of transform + to line endings as well. For example, C++ ofstream objects + usually convert '\n' into whatever a normal newline is for + your platform unless you open a file in binary mode. But + aside from file streams the ostream objects usually don't + modify the data you pass to them. + !*/ + + void set_line_ending ( + line_ending_type eol_style + ); + /*! + ensures + - #line_ending() == eol_style + !*/ + + void encode ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads all data from in (until EOF is reached) and encodes it + and writes it to out + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + void decode ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads data from in (until EOF is reached), decodes it, + and writes it to out. + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - decode_error + if an error was detected in the encoded data that prevented + it from being correctly decoded then this exception is + thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + private: + + // restricted functions + base64(base64&); // copy constructor + base64& operator=(base64&); // assignment operator + + }; + +} + +#endif // DLIB_BASE64_KERNEl_ABSTRACT_ + diff --git a/dlib/bayes_utils.h b/dlib/bayes_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..51ef6d2ed3d02cd1fe56b0bb2a6d2a7f7682e8ff --- /dev/null +++ b/dlib/bayes_utils.h @@ -0,0 +1,11 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BAYES_UTILs_H_ +#define DLIB_BAYES_UTILs_H_ + +#include "bayes_utils/bayes_utils.h" + +#endif // DLIB_BAYES_UTILs_H_ + + + diff --git a/dlib/bayes_utils/bayes_utils.h b/dlib/bayes_utils/bayes_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..c5ebc19b92a71e836ed787ad054dd376139f23e3 --- /dev/null +++ b/dlib/bayes_utils/bayes_utils.h @@ -0,0 +1,1677 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BAYES_UTILs_ +#define DLIB_BAYES_UTILs_ + +#include "bayes_utils_abstract.h" + +#include +#include +#include +#include + +#include "../string.h" +#include "../map.h" +#include "../matrix.h" +#include "../rand.h" +#include "../array.h" +#include "../set.h" +#include "../algs.h" +#include "../noncopyable.h" +#include "../graph.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class assignment + { + public: + + assignment() + { + } + + assignment( + const assignment& a + ) + { + a.reset(); + while (a.move_next()) + { + unsigned long idx = a.element().key(); + unsigned long value = a.element().value(); + vals.add(idx,value); + } + } + + assignment& operator = ( + const assignment& rhs + ) + { + if (this == &rhs) + return *this; + + assignment(rhs).swap(*this); + return *this; + } + + void clear() + { + vals.clear(); + } + + bool operator < ( + const assignment& item + ) const + { + if (size() < item.size()) + return true; + else if (size() > item.size()) + return false; + + reset(); + item.reset(); + while (move_next()) + { + item.move_next(); + if (element().key() < item.element().key()) + return true; + else if (element().key() > item.element().key()) + return false; + else if (element().value() < item.element().value()) + return true; + else if (element().value() > item.element().value()) + return false; + } + + return false; + } + + bool has_index ( + unsigned long idx + ) const + { + return vals.is_in_domain(idx); + } + + void add ( + unsigned long idx, + unsigned long value = 0 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == false , + "\tvoid assignment::add(idx)" + << "\n\tYou can't add the same index to an assignment object more than once" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + vals.add(idx, value); + } + + unsigned long& operator[] ( + const long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::operator[](idx)" + << "\n\tYou can't access an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + return vals[idx]; + } + + const unsigned long& operator[] ( + const long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::operator[](idx)" + << "\n\tYou can't access an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + return vals[idx]; + } + + void swap ( + assignment& item + ) + { + vals.swap(item.vals); + } + + void remove ( + unsigned long idx + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( has_index(idx) == true , + "\tunsigned long assignment::remove(idx)" + << "\n\tYou can't remove an index value if it isn't already in the object" + << "\n\tidx: " << idx + << "\n\tthis: " << this + ); + + vals.destroy(idx); + } + + unsigned long size() const { return vals.size(); } + + void reset() const { vals.reset(); } + + bool move_next() const { return vals.move_next(); } + + map_pair& element() + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tmap_pair& assignment::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + return vals.element(); + } + + const map_pair& element() const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst map_pair& assignment::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return vals.element(); + } + + bool at_start() const { return vals.at_start(); } + + bool current_element_valid() const { return vals.current_element_valid(); } + + friend inline void serialize ( + const assignment& item, + std::ostream& out + ) + { + serialize(item.vals, out); + } + + friend inline void deserialize ( + assignment& item, + std::istream& in + ) + { + deserialize(item.vals, in); + } + + private: + mutable dlib::map::kernel_1b_c vals; + }; + + inline std::ostream& operator << ( + std::ostream& out, + const assignment& a + ) + { + a.reset(); + out << "("; + if (a.move_next()) + out << a.element().key() << ":" << a.element().value(); + + while (a.move_next()) + { + out << ", " << a.element().key() << ":" << a.element().value(); + } + + out << ")"; + return out; + } + + + inline void swap ( + assignment& a, + assignment& b + ) + { + a.swap(b); + } + + +// ------------------------------------------------------------------------ + + class joint_probability_table + { + /*! + INITIAL VALUE + - table.size() == 0 + + CONVENTION + - size() == table.size() + - probability(a) == table[a] + !*/ + public: + + joint_probability_table ( + const joint_probability_table& t + ) + { + t.reset(); + while (t.move_next()) + { + assignment a = t.element().key(); + double p = t.element().value(); + set_probability(a,p); + } + } + + joint_probability_table() {} + + joint_probability_table& operator= ( + const joint_probability_table& rhs + ) + { + if (this == &rhs) + return *this; + joint_probability_table(rhs).swap(*this); + return *this; + } + + void set_probability ( + const assignment& a, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.0 <= p && p <= 1.0, + "\tvoid& joint_probability_table::set_probability(a,p)" + << "\n\tyou have given an invalid probability value" + << "\n\tp: " << p + << "\n\ta: " << a + << "\n\tthis: " << this + ); + + if (table.is_in_domain(a)) + { + table[a] = p; + } + else + { + assignment temp(a); + table.add(temp,p); + } + } + + bool has_entry_for ( + const assignment& a + ) const + { + return table.is_in_domain(a); + } + + void add_probability ( + const assignment& a, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(0.0 <= p && p <= 1.0, + "\tvoid& joint_probability_table::add_probability(a,p)" + << "\n\tyou have given an invalid probability value" + << "\n\tp: " << p + << "\n\ta: " << a + << "\n\tthis: " << this + ); + + if (table.is_in_domain(a)) + { + table[a] += p; + if (table[a] > 1.0) + table[a] = 1.0; + } + else + { + assignment temp(a); + table.add(temp,p); + } + } + + double probability ( + const assignment& a + ) const + { + return table[a]; + } + + void clear() + { + table.clear(); + } + + size_t size () const { return table.size(); } + bool move_next() const { return table.move_next(); } + void reset() const { table.reset(); } + map_pair& element() + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tmap_pair& joint_probability_table::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return table.element(); + } + + const map_pair& element() const + { + // make sure requires clause is not broken + DLIB_ASSERT(current_element_valid() == true, + "\tconst map_pair& joint_probability_table::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return table.element(); + } + + bool at_start() const { return table.at_start(); } + + bool current_element_valid() const { return table.current_element_valid(); } + + + template + void marginalize ( + const T& vars, + joint_probability_table& out + ) const + { + out.clear(); + double p; + reset(); + while (move_next()) + { + assignment a; + const assignment& asrc = element().key(); + p = element().value(); + + asrc.reset(); + while (asrc.move_next()) + { + if (vars.is_member(asrc.element().key())) + a.add(asrc.element().key(), asrc.element().value()); + } + + out.add_probability(a,p); + } + } + + void marginalize ( + const unsigned long var, + joint_probability_table& out + ) const + { + out.clear(); + double p; + reset(); + while (move_next()) + { + assignment a; + const assignment& asrc = element().key(); + p = element().value(); + + asrc.reset(); + while (asrc.move_next()) + { + if (var == asrc.element().key()) + a.add(asrc.element().key(), asrc.element().value()); + } + + out.add_probability(a,p); + } + } + + void normalize ( + ) + { + double sum = 0; + + reset(); + while (move_next()) + sum += element().value(); + + reset(); + while (move_next()) + element().value() /= sum; + } + + void swap ( + joint_probability_table& item + ) + { + table.swap(item.table); + } + + friend inline void serialize ( + const joint_probability_table& item, + std::ostream& out + ) + { + serialize(item.table, out); + } + + friend inline void deserialize ( + joint_probability_table& item, + std::istream& in + ) + { + deserialize(item.table, in); + } + + private: + + dlib::map::kernel_1b_c table; + }; + + inline void swap ( + joint_probability_table& a, + joint_probability_table& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + class conditional_probability_table : noncopyable + { + /*! + INITIAL VALUE + - table.size() == 0 + + CONVENTION + - if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) then + - has_entry_for(value,ps) == true + - probability(value,ps) == table[ps](value) + - else + - has_entry_for(value,ps) == false + + - num_values() == num_vals + !*/ + public: + + conditional_probability_table() + { + clear(); + } + + void set_num_values ( + unsigned long num + ) + { + num_vals = num; + table.clear(); + } + + bool has_entry_for ( + unsigned long value, + const assignment& ps + ) const + { + if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) + return true; + else + return false; + } + + unsigned long num_values ( + ) const { return num_vals; } + + void set_probability ( + unsigned long value, + const assignment& ps, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 , + "\tvoid conditional_probability_table::set_probability()" + << "\n\tinvalid arguments to set_probability" + << "\n\tvalue: " << value + << "\n\tnum_values(): " << num_values() + << "\n\tp: " << p + << "\n\tps: " << ps + << "\n\tthis: " << this + ); + + if (table.is_in_domain(ps)) + { + table[ps](value) = p; + } + else + { + matrix dist(num_vals); + set_all_elements(dist,-1); + dist(value) = p; + assignment temp(ps); + table.add(temp,dist); + } + } + + double probability( + unsigned long value, + const assignment& ps + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) , + "\tvoid conditional_probability_table::probability()" + << "\n\tinvalid arguments to probability" + << "\n\tvalue: " << value + << "\n\tnum_values(): " << num_values() + << "\n\tps: " << ps + << "\n\tthis: " << this + ); + + return table[ps](value); + } + + void clear() + { + table.clear(); + num_vals = 0; + } + + void empty_table () + { + table.clear(); + } + + void swap ( + conditional_probability_table& item + ) + { + exchange(num_vals, item.num_vals); + table.swap(item.table); + } + + friend inline void serialize ( + const conditional_probability_table& item, + std::ostream& out + ) + { + serialize(item.table, out); + serialize(item.num_vals, out); + } + + friend inline void deserialize ( + conditional_probability_table& item, + std::istream& in + ) + { + deserialize(item.table, in); + deserialize(item.num_vals, in); + } + + private: + dlib::map >::kernel_1b_c table; + unsigned long num_vals; + }; + + inline void swap ( + conditional_probability_table& a, + conditional_probability_table& b + ) { a.swap(b); } + +// ------------------------------------------------------------------------ + + class bayes_node : noncopyable + { + public: + bayes_node () + { + is_instantiated = false; + value_ = 0; + } + + unsigned long value ( + ) const { return value_;} + + void set_value ( + unsigned long new_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( new_value < table().num_values(), + "\tvoid bayes_node::set_value(new_value)" + << "\n\tnew_value must be less than the number of possible values for this node" + << "\n\tnew_value: " << new_value + << "\n\ttable().num_values(): " << table().num_values() + << "\n\tthis: " << this + ); + + value_ = new_value; + } + + conditional_probability_table& table ( + ) { return table_; } + + const conditional_probability_table& table ( + ) const { return table_; } + + bool is_evidence ( + ) const { return is_instantiated; } + + void set_as_nonevidence ( + ) { is_instantiated = false; } + + void set_as_evidence ( + ) { is_instantiated = true; } + + void swap ( + bayes_node& item + ) + { + exchange(value_, item.value_); + exchange(is_instantiated, item.is_instantiated); + table_.swap(item.table_); + } + + friend inline void serialize ( + const bayes_node& item, + std::ostream& out + ) + { + serialize(item.value_, out); + serialize(item.is_instantiated, out); + serialize(item.table_, out); + } + + friend inline void deserialize ( + bayes_node& item, + std::istream& in + ) + { + deserialize(item.value_, in); + deserialize(item.is_instantiated, in); + deserialize(item.table_, in); + } + + private: + + unsigned long value_; + bool is_instantiated; + conditional_probability_table table_; + }; + + inline void swap ( + bayes_node& a, + bayes_node& b + ) { a.swap(b); } + +// ------------------------------------------------------------------------ + + namespace bayes_node_utils + { + + template + unsigned long node_num_values ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::node_num_values(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.table().num_values(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_value ( + T& bn, + unsigned long n, + unsigned long val + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n), + "\tvoid bayes_node_utils::set_node_value(bn, n, val)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tval: " << val + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + bn.node(n).data.set_value(val); + } + + // ---------------------------------------------------------------------------------------- + template + unsigned long node_value ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tunsigned long bayes_node_utils::node_value(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.value(); + } + // ---------------------------------------------------------------------------------------- + + template + bool node_is_evidence ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_is_evidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + return bn.node(n).data.is_evidence(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_as_evidence ( + T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_as_evidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.set_as_evidence(); + } + + // ---------------------------------------------------------------------------------------- + template + void set_node_as_nonevidence ( + T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.set_as_nonevidence(); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_num_values ( + T& bn, + unsigned long n, + unsigned long num + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tvoid bayes_node_utils::set_node_num_values(bn, n, num)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + bn.node(n).data.table().set_num_values(num); + } + + // ---------------------------------------------------------------------------------------- + + template + double node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tvalue: " << value + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tparents.size(): " << parents.size() + << "\n\tb.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + +#ifdef ENABLE_ASSERTS + parents.reset(); + while (parents.move_next()) + { + const unsigned long x = parents.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( parents[x] < node_num_values(bn,x), + "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\tparents[x]: " << parents[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + return bn.node(n).data.table().probability(value, parents); + } + + // ---------------------------------------------------------------------------------------- + + template + void set_node_probability ( + T& bn, + unsigned long n, + unsigned long value, + const assignment& parents, + double p + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + << "\n\tvalue: " << value + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) + ); + + DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + << "\n\tparents.size(): " << parents.size() + << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + + DLIB_ASSERT( 0.0 <= p && p <= 1.0, + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tp: " << p + ); + +#ifdef ENABLE_ASSERTS + parents.reset(); + while (parents.move_next()) + { + const unsigned long x = parents.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( parents[x] < node_num_values(bn,x), + "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\tparents[x]: " << parents[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + bn.node(n).data.table().set_probability(value,parents,p); + } + +// ---------------------------------------------------------------------------------------- + + template + const assignment node_first_parent_assignment ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + ); + + assignment a; + const unsigned long num_parents = bn.node(n).number_of_parents(); + for (unsigned long i = 0; i < num_parents; ++i) + { + a.add(bn.node(n).parent(i).index(), 0); + } + return a; + } + +// ---------------------------------------------------------------------------------------- + + template + bool node_next_parent_assignment ( + const T& bn, + unsigned long n, + assignment& a + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + ); + + DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\ta.size(): " << a.size() + << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() + ); + +#ifdef ENABLE_ASSERTS + a.reset(); + while (a.move_next()) + { + const unsigned long x = a.element().key(); + DLIB_ASSERT( bn.has_edge(x, n), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + ); + DLIB_ASSERT( a[x] < node_num_values(bn,x), + "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tx: " << x + << "\n\ta[x]: " << a[x] + << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) + ); + } +#endif + + // basically this loop just adds 1 to the assignment but performs + // carries if necessary + for (unsigned long p = 0; p < a.size(); ++p) + { + const unsigned long pindex = bn.node(n).parent(p).index(); + a[pindex] += 1; + + // if we need to perform a carry + if (a[pindex] >= node_num_values(bn,pindex)) + { + a[pindex] = 0; + } + else + { + // no carry necessary so we are done + return true; + } + } + + // we got through the entire loop which means a carry propagated all the way out + // so there must not be any more valid assignments left + return false; + } + +// ---------------------------------------------------------------------------------------- + + template + bool node_cpt_filled_out ( + const T& bn, + unsigned long n + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( n < bn.number_of_nodes(), + "\tbool bayes_node_utils::node_cpt_filled_out(bn, n)" + << "\n\tInvalid arguments to this function" + << "\n\tn: " << n + << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() + ); + + const unsigned long num_values = node_num_values(bn,n); + + + const conditional_probability_table& table = bn.node(n).data.table(); + + // now loop over all the possible parent assignments for this node + assignment a(node_first_parent_assignment(bn,n)); + do + { + double sum = 0; + // make sure that this assignment has an entry for all the values this node can take one + for (unsigned long value = 0; value < num_values; ++value) + { + if (table.has_entry_for(value,a) == false) + return false; + else + sum += table.probability(value,a); + } + + // check if the sum of probabilities equals 1 as it should + if (std::abs(sum-1.0) > 1e-5) + return false; + } while (node_next_parent_assignment(bn,n,a)); + + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + class bayesian_network_gibbs_sampler : noncopyable + { + public: + + bayesian_network_gibbs_sampler () + { + rnd.set_seed(cast_to_string(std::time(0))); + } + + + template < + typename T + > + void sample_graph ( + T& bn + ) + { + using namespace bayes_node_utils; + for (unsigned long n = 0; n < bn.number_of_nodes(); ++n) + { + if (node_is_evidence(bn, n)) + continue; + + samples.set_size(node_num_values(bn,n)); + // obtain the probability distribution for this node + for (long i = 0; i < samples.nc(); ++i) + { + set_node_value(bn, n, i); + samples(i) = node_probability(bn, n); + + for (unsigned long j = 0; j < bn.node(n).number_of_children(); ++j) + samples(i) *= node_probability(bn, bn.node(n).child(j).index()); + } + + //normalize samples + samples /= sum(samples); + + + // select a random point in the probability distribution + double prob = rnd.get_random_double(); + + // now find the point in the distribution this probability corresponds to + long j; + for (j = 0; j < samples.nc()-1; ++j) + { + if (prob <= samples(j)) + break; + else + prob -= samples(j); + } + + set_node_value(bn, n, j); + } + } + + + private: + + template < + typename T + > + double node_probability ( + const T& bn, + unsigned long n + ) + /*! + requires + - n < bn.number_of_nodes() + ensures + - computes the probability of node n having its current value given + the current values of its parents in the network bn + !*/ + { + v.clear(); + for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i) + { + v.add(bn.node(n).parent(i).index(), bn.node(n).parent(i).data.value()); + } + return bn.node(n).data.table().probability(bn.node(n).data.value(), v); + } + + assignment v; + + dlib::rand rnd; + matrix samples; + }; + +// ---------------------------------------------------------------------------------------- + + namespace bayesian_network_join_tree_helpers + { + class bnjt + { + /*! + this object is the base class used in this pimpl idiom + !*/ + public: + virtual ~bnjt() {} + + virtual const matrix probability( + unsigned long idx + ) const = 0; + }; + + template + class bnjt_impl : public bnjt + { + /*! + This object is the implementation in the pimpl idiom + !*/ + + public: + + bnjt_impl ( + const T& bn, + const U& join_tree + ) + { + create_bayesian_network_join_tree(bn, join_tree, join_tree_values); + + cliques.resize(bn.number_of_nodes()); + + // figure out which cliques contain each node + for (unsigned long i = 0; i < cliques.size(); ++i) + { + // find the smallest clique that contains node with index i + unsigned long smallest_clique = 0; + unsigned long size = std::numeric_limits::max(); + + for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n) + { + if (join_tree.node(n).data.is_member(i) && join_tree.node(n).data.size() < size) + { + size = join_tree.node(n).data.size(); + smallest_clique = n; + } + } + + cliques[i] = smallest_clique; + } + } + + virtual const matrix probability( + unsigned long idx + ) const + { + join_tree_values.node(cliques[idx]).data.marginalize(idx, table); + table.normalize(); + var.clear(); + var.add(idx); + dist.set_size(table.size()); + + // read the probabilities out of the table and into the row matrix + for (unsigned long i = 0; i < table.size(); ++i) + { + var[idx] = i; + dist(i) = table.probability(var); + } + + return dist; + } + + private: + + graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values; + array cliques; + mutable joint_probability_table table; + mutable assignment var; + mutable matrix dist; + + + // ---------------------------------------------------------------------------------------- + + template + bool set_contains_all_parents_of_node ( + const set_type& set, + const node_type& node + ) + { + for (unsigned long i = 0; i < node.number_of_parents(); ++i) + { + if (set.is_member(node.parent(i).index()) == false) + return false; + } + return true; + } + + // ---------------------------------------------------------------------------------------- + + template < + typename V + > + void pass_join_tree_message ( + const U& join_tree, + V& bn_join_tree , + unsigned long from, + unsigned long to + ) + { + using namespace bayes_node_utils; + const typename U::edge_type& e = edge(join_tree, from, to); + typename V::edge_type& old_s = edge(bn_join_tree, from, to); + + typedef typename V::edge_type joint_prob_table; + + joint_prob_table new_s; + bn_join_tree.node(from).data.marginalize(e, new_s); + + joint_probability_table temp(new_s); + // divide new_s by old_s and store the result in temp. + // if old_s is empty then that is the same as if it was all 1s + // so we don't have to do this if that is the case. + if (old_s.size() > 0) + { + temp.reset(); + old_s.reset(); + while (temp.move_next()) + { + old_s.move_next(); + if (old_s.element().value() != 0) + temp.element().value() /= old_s.element().value(); + } + } + + // now multiply temp by d and store the results in d + joint_probability_table& d = bn_join_tree.node(to).data; + d.reset(); + while (d.move_next()) + { + assignment a; + const assignment& asrc = d.element().key(); + asrc.reset(); + while (asrc.move_next()) + { + if (e.is_member(asrc.element().key())) + a.add(asrc.element().key(), asrc.element().value()); + } + + d.element().value() *= temp.probability(a); + + } + + // store new_s in old_s + new_s.swap(old_s); + + } + + // ---------------------------------------------------------------------------------------- + + template < + typename V + > + void create_bayesian_network_join_tree ( + const T& bn, + const U& join_tree, + V& bn_join_tree + ) + /*! + requires + - bn is a proper bayesian network + - join_tree is the join tree for that bayesian network + ensures + - bn_join_tree == the output of the join tree algorithm for bayesian network inference. + So each node in this graph contains a joint_probability_table for the clique + in the corresponding node in the join_tree graph. + !*/ + { + using namespace bayes_node_utils; + bn_join_tree.clear(); + copy_graph_structure(join_tree, bn_join_tree); + + // we need to keep track of which node is "in" each clique for the purposes of + // initializing the tables in each clique. So this vector will be used to do that + // and a value of join_tree.number_of_nodes() means that the node with + // that index is unassigned. + std::vector node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes()); + + // populate evidence with all the evidence node indices and their values + dlib::map::kernel_1b_c evidence; + for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) + { + if (node_is_evidence(bn, i)) + { + unsigned long idx = i; + unsigned long value = node_value(bn, i); + evidence.add(idx,value); + } + } + + + // initialize the bn join tree + for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) + { + bool contains_evidence = false; + std::vector indices; + assignment value; + + // loop over all the nodes in this clique in the join tree. In this loop + // we are making an assignment with all the values of the nodes it represents set to 0 + join_tree.node(i).data.reset(); + while (join_tree.node(i).data.move_next()) + { + const unsigned long idx = join_tree.node(i).data.element(); + indices.push_back(idx); + value.add(idx); + + if (evidence.is_in_domain(join_tree.node(i).data.element())) + contains_evidence = true; + } + + // now loop over all possible combinations of values that the nodes this + // clique in the join tree can take on. We do this by counting by one through all + // legal values + bool more_assignments = true; + while (more_assignments) + { + bn_join_tree.node(i).data.set_probability(value,1); + + // account for any evidence + if (contains_evidence) + { + // loop over all the nodes in this cluster + for (unsigned long j = 0; j < indices.size(); ++j) + { + // if the current node is an evidence node + if (evidence.is_in_domain(indices[j])) + { + const unsigned long idx = indices[j]; + const unsigned long evidence_value = evidence[idx]; + if (value[idx] != evidence_value) + bn_join_tree.node(i).data.set_probability(value , 0); + } + } + } + + + // now check if any of the nodes in this cluster also have their parents in this cluster + join_tree.node(i).data.reset(); + while (join_tree.node(i).data.move_next()) + { + const unsigned long idx = join_tree.node(i).data.element(); + // if this clique contains all the parents of this node and also hasn't + // been assigned to another clique + if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) && + (i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) ) + { + // note that this node is now assigned to this clique + node_assigned_to[idx] = i; + // node idx has all its parents in the cluster + assignment parent_values; + for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j) + { + const unsigned long pidx = bn.node(idx).parent(j).index(); + parent_values.add(pidx, value[pidx]); + } + + double temp = bn_join_tree.node(i).data.probability(value); + bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values)); + + } + } + + + // now advance the value variable to its next possible state if there is one + more_assignments = false; + value.reset(); + while (value.move_next()) + { + value.element().value() += 1; + // if overflow + if (value.element().value() == node_num_values(bn, value.element().key())) + { + value.element().value() = 0; + } + else + { + more_assignments = true; + break; + } + } + + } // end while (more_assignments) + } + + + + + // the tree is now initialized. Now all we need to do is perform the propagation and + // we are done + dlib::array::compare_1b_c> remaining_msg_to_send; + dlib::array::compare_1b_c> remaining_msg_to_receive; + remaining_msg_to_receive.resize(join_tree.number_of_nodes()); + remaining_msg_to_send.resize(join_tree.number_of_nodes()); + for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i) + { + for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j) + { + const unsigned long idx = join_tree.node(i).neighbor(j).index(); + unsigned long temp; + temp = idx; remaining_msg_to_receive[i].add(temp); + temp = idx; remaining_msg_to_send[i].add(temp); + } + } + + // now remaining_msg_to_receive[i] contains all the nodes that node i hasn't yet received + // a message from. + // we will consider node 0 to be the root node. + + + bool message_sent = true; + while (message_sent) + { + message_sent = false; + for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i) + { + // if node i hasn't sent any messages but has received all but one then send a message to the one + // node who hasn't sent i a message + if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1) + { + unsigned long to; + // get the last remaining thing from this set + remaining_msg_to_receive[i].remove_any(to); + + // send the message + pass_join_tree_message(join_tree, bn_join_tree, i, to); + + // record that we sent this message + remaining_msg_to_send[i].destroy(to); + remaining_msg_to_receive[to].destroy(i); + + // put to back in since we still need to receive it + remaining_msg_to_receive[i].add(to); + message_sent = true; + } + else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0) + { + unsigned long to; + remaining_msg_to_send[i].remove_any(to); + remaining_msg_to_receive[to].destroy(i); + pass_join_tree_message(join_tree, bn_join_tree, i, to); + message_sent = true; + } + } + + if (remaining_msg_to_receive[0].size() == 0) + { + // send a message to all of the root nodes neighbors unless we have already sent out he messages + while (remaining_msg_to_send[0].size() > 0) + { + unsigned long to; + remaining_msg_to_send[0].remove_any(to); + remaining_msg_to_receive[to].destroy(0); + pass_join_tree_message(join_tree, bn_join_tree, 0, to); + message_sent = true; + } + } + + + } + + } + + }; + } + + class bayesian_network_join_tree : noncopyable + { + /*! + use the pimpl idiom to push the template arguments from the class level to the + constructor level + !*/ + + public: + + template < + typename T, + typename U + > + bayesian_network_join_tree ( + const T& bn, + const U& join_tree + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( bn.number_of_nodes() > 0 , + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + + DLIB_ASSERT( is_join_tree(bn, join_tree) == true , + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid join tree for the supplied bayesian network" + << "\n\tthis: " << this + ); + DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + DLIB_ASSERT( graph_is_connected(bn) == true, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network" + << "\n\tthis: " << this + ); + +#ifdef ENABLE_ASSERTS + for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) + { + DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true, + "\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" + << "\n\tYou have given an invalid bayesian network. " + << "\n\tYou must finish filling out the conditional_probability_table of node " << i + << "\n\tthis: " << this + ); + } +#endif + + impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl(bn, join_tree)); + num_nodes = bn.number_of_nodes(); + } + + const matrix probability( + unsigned long idx + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT( idx < number_of_nodes() , + "\tconst matrix bayesian_network_join_tree::probability(idx)" + << "\n\tYou have specified an invalid node index" + << "\n\tidx: " << idx + << "\n\tnumber_of_nodes(): " << number_of_nodes() + << "\n\tthis: " << this + ); + + return impl->probability(idx); + } + + unsigned long number_of_nodes ( + ) const { return num_nodes; } + + void swap ( + bayesian_network_join_tree& item + ) + { + exchange(num_nodes, item.num_nodes); + impl.swap(item.impl); + } + + private: + + std::unique_ptr impl; + unsigned long num_nodes; + + }; + + inline void swap ( + bayesian_network_join_tree& a, + bayesian_network_join_tree& b + ) { a.swap(b); } + +} + +// ---------------------------------------------------------------------------------------- + +#endif // DLIB_BAYES_UTILs_ + diff --git a/dlib/bayes_utils/bayes_utils_abstract.h b/dlib/bayes_utils/bayes_utils_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..b19e6e1da1844ccc9cf09a6139866822d1971a0f --- /dev/null +++ b/dlib/bayes_utils/bayes_utils_abstract.h @@ -0,0 +1,1042 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BAYES_UTILs_ABSTRACT_ +#ifdef DLIB_BAYES_UTILs_ABSTRACT_ + +#include "../algs.h" +#include "../noncopyable.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/map_pair.h" +#include "../serialize.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class assignment : public enumerable > + { + /*! + INITIAL VALUE + - size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the entries in the assignment in + ascending order according to index values. (i.e. the elements are + enumerated in sorted order according to the value of their keys) + + WHAT THIS OBJECT REPRESENTS + This object models an assignment of random variables to particular values. + It is used with the joint_probability_table and conditional_probability_table + objects to represent assignments of various random variables to actual values. + + So for example, if you had a joint_probability_table that represented the + following table: + P(A = 0, B = 0) = 0.2 + P(A = 0, B = 1) = 0.3 + P(A = 1, B = 0) = 0.1 + P(A = 1, B = 1) = 0.4 + + Also lets define an enum so we have concrete index numbers for A and B + enum { A = 0, B = 1}; + + Then you could query the value of P(A=1, B=0) as follows: + assignment a; + a.set(A, 1); + a.set(B, 0); + // and now it is the case that: + table.probability(a) == 0.1 + a[A] == 1 + a[B] == 0 + + + Also note that when enumerating the elements of an assignment object + the key() refers to the index and the value() refers to the value at that + index. For example: + + // assume a is an assignment object + a.reset(); + while (a.move_next()) + { + // in this loop it is always the case that: + // a[a.element().key()] == a.element().value() + } + !*/ + + public: + + assignment( + ); + /*! + ensures + - this object is properly initialized + !*/ + + assignment( + const assignment& a + ); + /*! + ensures + - #*this is a copy of a + !*/ + + assignment& operator = ( + const assignment& rhs + ); + /*! + ensures + - #*this is a copy of rhs + - returns *this + !*/ + + void clear( + ); + /*! + ensures + - this object has been returned to its initial value + !*/ + + bool operator < ( + const assignment& item + ) const; + /*! + ensures + - The exact functioning of this operator is undefined. The only guarantee + is that it establishes a total ordering on all possible assignment objects. + In other words, this operator makes it so that you can use assignment + objects in the associative containers but otherwise isn't of any + particular use. + !*/ + + bool has_index ( + unsigned long idx + ) const; + /*! + ensures + - if (this assignment object has an entry for index idx) then + - returns true + - else + - returns false + !*/ + + void add ( + unsigned long idx, + unsigned long value = 0 + ); + /*! + requires + - has_index(idx) == false + ensures + - #has_index(idx) == true + - #(*this)[idx] == value + !*/ + + void remove ( + unsigned long idx + ); + /*! + requires + - has_index(idx) == true + ensures + - #has_index(idx) == false + !*/ + + unsigned long& operator[] ( + const long idx + ); + /*! + requires + - has_index(idx) == true + ensures + - returns a reference to the value associated with index idx + !*/ + + const unsigned long& operator[] ( + const long idx + ) const; + /*! + requires + - has_index(idx) == true + ensures + - returns a const reference to the value associated with index idx + !*/ + + void swap ( + assignment& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + assignment& a, + assignment& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + std::ostream& operator << ( + std::ostream& out, + const assignment& a + ); + /*! + ensures + - writes a to the given output stream in the following format: + (index1:value1, index2:value2, ..., indexN:valueN) + !*/ + + void serialize ( + const assignment& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + assignment& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ------------------------------------------------------------------------ + + class joint_probability_table : public enumerable > + { + /*! + INITIAL VALUE + - size() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the entries in the probability table + in no particular order but they will all be visited. + + WHAT THIS OBJECT REPRESENTS + This object models a joint probability table. That is, it models + the function p(X). So this object models the probability of a particular + set of variables (referred to as X). + !*/ + + public: + + joint_probability_table( + ); + /*! + ensures + - this object is properly initialized + !*/ + + joint_probability_table ( + const joint_probability_table& t + ); + /*! + ensures + - this object is a copy of t + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + !*/ + + joint_probability_table& operator= ( + const joint_probability_table& rhs + ); + /*! + ensures + - this object is a copy of rhs + - returns a reference to *this + !*/ + + bool has_entry_for ( + const assignment& a + ) const; + /*! + ensures + - if (this joint_probability_table has an entry for p(X = a)) then + - returns true + - else + - returns false + !*/ + + void set_probability ( + const assignment& a, + double p + ); + /*! + requires + - 0 <= p <= 1 + ensures + - if (has_entry_for(a) == false) then + - #size() == size() + 1 + - #probability(a) == p + - #has_entry_for(a) == true + !*/ + + void add_probability ( + const assignment& a, + double p + ); + /*! + requires + - 0 <= p <= 1 + ensures + - if (has_entry_for(a) == false) then + - #size() == size() + 1 + - #probability(a) == p + - else + - #probability(a) == min(probability(a) + p, 1.0) + (i.e. does a saturating add) + - #has_entry_for(a) == true + !*/ + + const double probability ( + const assignment& a + ) const; + /*! + ensures + - returns the probability p(X == a) + !*/ + + template < + typename T + > + void marginalize ( + const T& vars, + joint_probability_table& output_table + ) const; + /*! + requires + - T is an implementation of set/set_kernel_abstract.h + ensures + - marginalizes *this by summing over all variables not in vars. The + result is stored in output_table. + !*/ + + void marginalize ( + const unsigned long var, + joint_probability_table& output_table + ) const; + /*! + ensures + - is identical to calling the above marginalize() function with a set + that contains only var. Or in other words, performs a marginalization + with just one variable var. So that output_table will contain a table giving + the marginal probability of var all by itself. + !*/ + + void normalize ( + ); + /*! + ensures + - let sum == the sum of all the probabilities in this table + - after normalize() has finished it will be the case that the sum of all + the entries in this table is 1.0. This is accomplished by dividing all + the entries by the sum described above. + !*/ + + void swap ( + joint_probability_table& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + joint_probability_table& a, + joint_probability_table& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const joint_probability_table& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + joint_probability_table& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- + + class conditional_probability_table : noncopyable + { + /*! + INITIAL VALUE + - num_values() == 0 + - has_value_for(x, y) == false for all values of x and y + + WHAT THIS OBJECT REPRESENTS + This object models a conditional probability table. That is, it models + the function p( X | parents). So this object models the conditional + probability of a particular variable (referred to as X) given another set + of variables (referred to as parents). + !*/ + + public: + + conditional_probability_table( + ); + /*! + ensures + - this object is properly initialized + !*/ + + void clear( + ); + /*! + ensures + - this object has its initial value + !*/ + + void empty_table ( + ); + /*! + ensures + - for all possible v and p: + - #has_entry_for(v,p) == false + (i.e. this function clears out the table when you call it but doesn't + change the value of num_values()) + !*/ + + void set_num_values ( + unsigned long num + ); + /*! + ensures + - #num_values() == num + - for all possible v and p: + - #has_entry_for(v,p) == false + (i.e. this function clears out the table when you call it) + !*/ + + unsigned long num_values ( + ) const; + /*! + ensures + - This object models the probability table p(X | parents). This + function returns the number of values X can take on. + !*/ + + bool has_entry_for ( + unsigned long value, + const assignment& ps + ) const; + /*! + ensures + - if (this conditional_probability_table has an entry for p(X = value, parents = ps)) then + - returns true + - else + - returns false + !*/ + + void set_probability ( + unsigned long value, + const assignment& ps, + double p + ); + /*! + requires + - value < num_values() + - 0 <= p <= 1 + ensures + - #probability(ps, value) == p + - #has_entry_for(value, ps) == true + !*/ + + double probability( + unsigned long value, + const assignment& ps + ) const; + /*! + requires + - value < num_values() + - has_entry_for(value, ps) == true + ensures + - returns the probability p( X = value | parents = ps). + !*/ + + void swap ( + conditional_probability_table& item + ); + /*! + ensures + - swaps *this and item + !*/ + }; + + inline void swap ( + conditional_probability_table& a, + conditional_probability_table& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const conditional_probability_table& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + conditional_probability_table& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ------------------------------------------------------------------------ +// ------------------------------------------------------------------------ +// ------------------------------------------------------------------------ + + class bayes_node : noncopyable + { + /*! + INITIAL VALUE + - is_evidence() == false + - value() == 0 + - table().num_values() == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents a node in a bayesian network. It is + intended to be used inside the dlib::directed_graph object to + represent bayesian networks. + !*/ + + public: + bayes_node ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + unsigned long value ( + ) const; + /*! + ensures + - returns the current value of this node + !*/ + + void set_value ( + unsigned long new_value + ); + /*! + requires + - new_value < table().num_values() + ensures + - #value() == new_value + !*/ + + conditional_probability_table& table ( + ); + /*! + ensures + - returns a reference to the conditional_probability_table associated with this node + !*/ + + const conditional_probability_table& table ( + ) const; + /*! + ensures + - returns a const reference to the conditional_probability_table associated with this + node. + !*/ + + bool is_evidence ( + ) const; + /*! + ensures + - if (this is an evidence node) then + - returns true + - else + - returns false + !*/ + + void set_as_nonevidence ( + ); + /*! + ensures + - #is_evidence() == false + !*/ + + void set_as_evidence ( + ); + /*! + ensures + - #is_evidence() == true + !*/ + + void swap ( + bayes_node& item + ); + /*! + ensures + - swaps *this and item + !*/ + + }; + + inline void swap ( + bayes_node& a, + bayes_node& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + + void serialize ( + const bayes_node& item, + std::ostream& out + ); + /*! + provides serialization support + !*/ + + void deserialize ( + bayes_node& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + /* + The following group of functions are convenience functions for manipulating + bayes_node objects while they are inside a directed_graph. These functions + also have additional requires clauses that, in debug mode, will protect you + from attempts to manipulate a bayesian network in an inappropriate way. + */ + + namespace bayes_node_utils + { + + template < + typename T + > + void set_node_value ( + T& bn, + unsigned long n, + unsigned long val + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - val < node_num_values(bn, n) + ensures + - #bn.node(n).data.value() = val + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + unsigned long node_value ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.value() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + bool node_is_evidence ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.is_evidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_as_evidence ( + T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - executes: bn.node(n).data.set_as_evidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_as_nonevidence ( + T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - executes: bn.node(n).data.set_as_nonevidence() + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void set_node_num_values ( + T& bn, + unsigned long n, + unsigned long num + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - #bn.node(n).data.table().num_values() == num + (i.e. sets the number of different values this node can take) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + unsigned long node_num_values ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns bn.node(n).data.table().num_values() + (i.e. returns the number of different values this node can take) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + const double node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - value < node_num_values(bn,n) + - parents.size() == bn.node(n).number_of_parents() + - if (parents.has_index(x)) then + - bn.has_edge(x, n) + - parents[x] < node_num_values(bn,x) + ensures + - returns bn.node(n).data.table().probability(value, parents) + (i.e. returns the probability of node n having the given value when + its parents have the given assignment) + !*/ + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + const double set_node_probability ( + const T& bn, + unsigned long n, + unsigned long value, + const assignment& parents, + double p + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - value < node_num_values(bn,n) + - 0 <= p <= 1 + - parents.size() == bn.node(n).number_of_parents() + - if (parents.has_index(x)) then + - bn.has_edge(x, n) + - parents[x] < node_num_values(bn,x) + ensures + - #bn.node(n).data.table().probability(value, parents) == p + (i.e. sets the probability of node n having the given value when + its parents have the given assignment to the probability p) + !*/ + + // ------------------------------------------------------------------------------------ + + template + const assignment node_first_parent_assignment ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - returns an assignment A such that: + - A.size() == bn.node(n).number_of_parents() + - if (P is a parent of bn.node(n)) then + - A.has_index(P) + - A[P] == 0 + - I.e. this function returns an assignment that contains all + the parents of the given node. Also, all the values of each + parent in the assignment is set to zero. + !*/ + + // ------------------------------------------------------------------------------------ + + template + bool node_next_parent_assignment ( + const T& bn, + unsigned long n, + assignment& A + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + - A.size() == bn.node(n).number_of_parents() + - if (A.has_index(x)) then + - bn.has_edge(x, n) + - A[x] < node_num_values(bn,x) + ensures + - The behavior of this function is defined by the following code: + assignment a(node_first_parent_assignment(bn,n); + do { + // this loop loops over all possible parent assignments + // of the node bn.node(n). Each time through the loop variable a + // will be the next assignment. + } while (node_next_parent_assignment(bn,n,a)) + !*/ + + // ------------------------------------------------------------------------------------ + + template + bool node_cpt_filled_out ( + const T& bn, + unsigned long n + ); + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + - n < bn.number_of_nodes() + ensures + - if (the conditional_probability_table bn.node(n).data.table() is + fully filled out for this node) then + - returns true + - This means that each parent assignment for the given node + along with all possible values of this node shows up in the + table. + - It also means that all the probabilities conditioned on the + same parent assignment sum to 1.0 + - else + - returns false + !*/ + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + class bayesian_network_gibbs_sampler : noncopyable + { + /*! + INITIAL VALUE + This object has no state + + WHAT THIS OBJECT REPRESENTS + This object performs Markov Chain Monte Carlo sampling of a bayesian + network using the Gibbs sampling technique. + + Note that this object is limited to only bayesian networks that + don't contain deterministic nodes. That is, incorrect results may + be computed if this object is used when the bayesian network contains + any nodes that have a probability of 1 in their conditional probability + tables for any event. So don't use this object for networks with + deterministic nodes. + !*/ + public: + + bayesian_network_gibbs_sampler ( + ); + /*! + ensures + - this object is properly initialized + !*/ + + template < + typename T + > + void sample_graph ( + T& bn + ) + /*! + requires + - T is an implementation of directed_graph/directed_graph_kernel_abstract.h + - T::type == bayes_node + ensures + - modifies randomly (via the Gibbs sampling technique) samples all the nodes + in the network and updates their values with the newly sampled values + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + class bayesian_network_join_tree : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an implementation of the join tree algorithm + for inference in bayesian networks. It doesn't have any mutable state. + To you use you just give it a directed_graph that contains a bayesian + network and a graph object that contains that networks corresponding + join tree. Then you may query this object to determine the probabilities + of any variables in the original bayesian network. + !*/ + + public: + + template < + typename bn_type, + typename join_tree_type + > + bayesian_network_join_tree ( + const bn_type& bn, + const join_tree_type& join_tree + ); + /*! + requires + - bn_type is an implementation of directed_graph/directed_graph_kernel_abstract.h + - bn_type::type == bayes_node + - join_tree_type is an implementation of graph/graph_kernel_abstract.h + - join_tree_type::type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - join_tree_type::edge_type is an implementation of set/set_compare_abstract.h and + this set type contains unsigned long objects. + - is_join_tree(bn, join_tree) == true + - bn == a valid bayesian network with all its conditional probability tables + filled out + - for all valid n: + - node_cpt_filled_out(bn,n) == true + - graph_contains_length_one_cycle(bn) == false + - graph_is_connected(bn) == true + - bn.number_of_nodes() > 0 + ensures + - this object is properly initialized + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of nodes in the bayesian network that this + object was instantiated from. + !*/ + + const matrix probability( + unsigned long idx + ) const; + /*! + requires + - idx < number_of_nodes() + ensures + - returns the probability distribution for the node with index idx that was in the bayesian + network that *this was instantiated from. Let D represent this distribution, then: + - D.nc() == the number of values the node idx ranges over + - D.nr() == 1 + - D(i) == the probability of node idx taking on the value i + !*/ + + void swap ( + bayesian_network_join_tree& item + ); + /*! + ensures + - swaps *this with item + !*/ + + }; + + inline void swap ( + bayesian_network_join_tree& a, + bayesian_network_join_tree& b + ) { a.swap(b); } + /*! + provides a global swap + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BAYES_UTILs_ABSTRACT_ + + diff --git a/dlib/bigint.h b/dlib/bigint.h new file mode 100644 index 0000000000000000000000000000000000000000..73496689ac9129e51a3f1ec0268cda4a342523c1 --- /dev/null +++ b/dlib/bigint.h @@ -0,0 +1,43 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINt_ +#define DLIB_BIGINt_ + +#include "bigint/bigint_kernel_1.h" +#include "bigint/bigint_kernel_2.h" +#include "bigint/bigint_kernel_c.h" + + + + +namespace dlib +{ + + + class bigint + { + bigint() {} + + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef bigint_kernel_1 + kernel_1a; + typedef bigint_kernel_c + kernel_1a_c; + + // kernel_2a + typedef bigint_kernel_2 + kernel_2a; + typedef bigint_kernel_c + kernel_2a_c; + + + }; +} + +#endif // DLIB_BIGINt_ + diff --git a/dlib/bigint/bigint_kernel_1.cpp b/dlib/bigint/bigint_kernel_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..feef761c227996c1dd04e90087c1c04a86c65639 --- /dev/null +++ b/dlib/bigint/bigint_kernel_1.cpp @@ -0,0 +1,1720 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEL_1_CPp_ +#define DLIB_BIGINT_KERNEL_1_CPp_ +#include "bigint_kernel_1.h" + +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member/friend function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + ) : + slack(25), + data(new data_record(slack)) + {} + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + uint32 value + ) : + slack(25), + data(new data_record(slack)) + { + *(data->number) = static_cast(value&0xFFFF); + *(data->number+1) = static_cast((value>>16)&0xFFFF); + if (*(data->number+1) != 0) + data->digits_used = 2; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + bigint_kernel_1 ( + const bigint_kernel_1& item + ) : + slack(25), + data(item.data) + { + data->references += 1; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1:: + ~bigint_kernel_1 ( + ) + { + if (data->references == 1) + { + delete data; + } + else + { + data->references -= 1; + } + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator+ ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + std::max(rhs.data->digits_used,data->digits_used) + slack + ); + long_add(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator+= ( + const bigint_kernel_1& rhs + ) + { + // if there are other references to our data + if (data->references != 1) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + data->references -= 1; + long_add(data,rhs.data,temp); + data = temp; + } + // if data is not big enough for the result + else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + long_add(data,rhs.data,temp); + delete data; + data = temp; + } + // there is enough size and no references + else + { + long_add(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator- ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + slack + ); + long_sub(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-= ( + const bigint_kernel_1& rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + long_sub(data,rhs.data,temp); + data = temp; + } + else + { + long_sub(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator* ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + rhs.data->digits_used + slack + ); + long_mul(data,rhs.data,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator*= ( + const bigint_kernel_1& rhs + ) + { + // create a data_record to store the result of the multiplication in + data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); + long_mul(data,rhs.data,temp); + + // if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + else + { + delete data; + } + data = temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator/ ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete remainder; + + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator/= ( + const bigint_kernel_1& rhs + ) + { + + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = temp; + delete remainder; + + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator% ( + const bigint_kernel_1& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete temp; + return bigint_kernel_1(remainder,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator%= ( + const bigint_kernel_1& rhs + ) + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = remainder; + delete temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + operator < ( + const bigint_kernel_1& rhs + ) const + { + return is_less_than(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + operator == ( + const bigint_kernel_1& rhs + ) const + { + return is_equal_to(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator= ( + const bigint_kernel_1& rhs + ) + { + if (this == &rhs) + return *this; + + // if we have the only reference to our data then delete it + if (data->references == 1) + { + delete data; + data = rhs.data; + data->references += 1; + } + else + { + data->references -= 1; + data = rhs.data; + data->references += 1; + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out_, + const bigint_kernel_1& rhs + ) + { + std::ostream out(out_.rdbuf()); + + typedef bigint_kernel_1 bigint; + + bigint::data_record* temp = new bigint::data_record(*rhs.data,0); + + + + // get a char array big enough to hold the number in ascii format + char* str; + try { + str = new char[(rhs.data->digits_used)*5+10]; + } catch (...) { delete temp; throw; } + + char* str_start = str; + str += (rhs.data->digits_used)*5+9; + *str = 0; --str; + + + uint16 remainder; + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + + + // keep looping until temp represents zero + while (temp->digits_used != 1 || *(temp->number) != 0) + { + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + } + + // throw away and extra leading zeros + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + + + + + out << str; + delete [] str_start; + delete temp; + return out_; + + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in_, + bigint_kernel_1& rhs + ) + { + std::istream in(in_.rdbuf()); + + // ignore any leading whitespaces + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') + { + in.get(); + } + + // if the first digit is not an integer then this is an error + if ( !(in.peek() >= '0' && in.peek() <= '9')) + { + in_.clear(std::ios::failbit); + return in_; + } + + int num_read; + bigint_kernel_1 temp; + do + { + + // try to get 4 chars from in + num_read = 1; + char a = 0; + char b = 0; + char c = 0; + char d = 0; + + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + a = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + b = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + c = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + d = in.get(); + } + + // merge the for digits into an uint16 + uint16 num = 0; + if (a != 0) + { + num = a - '0'; + } + if (b != 0) + { + num *= 10; + num += b - '0'; + } + if (c != 0) + { + num *= 10; + num += c - '0'; + } + if (d != 0) + { + num *= 10; + num += d - '0'; + } + + + if (num_read != 1) + { + // shift the digits in temp left by the number of new digits we just read + temp *= num_read; + // add in new digits + temp += num; + } + + } while (num_read == 10000); + + + rhs = temp; + return in_; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator+ ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_add(rhs.data,lhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator+ ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_add(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator+= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_add(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_add(data,rhs,temp); + delete data; + data = temp; + } + // or if there is plenty of space and no references + else + { + short_add(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator- ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + *(temp->number) = lhs - *(rhs.data->number); + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator- ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_sub(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_sub(data,rhs,temp); + data = temp; + } + else + { + short_sub(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator* ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_mul(rhs.data,lhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator* ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_mul(lhs.data,rhs,temp); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator*= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_mul(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_mul(data,rhs,temp); + delete data; + data = temp; + } + else + { + short_mul(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator/ ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + // if rhs might not be bigger than lhs + if (rhs.data->digits_used == 1) + { + *(temp->number) = lhs/ *(rhs.data->number); + } + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator/ ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + uint16 remainder; + lhs.short_div(lhs.data,rhs,temp,remainder); + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator/= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator% ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + typedef bigint_kernel_1 bigint; + // temp is zero by default + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + if (rhs.data->digits_used == 1) + { + // if rhs is just an uint16 inside then perform the modulus + *(temp->number) = lhs % *(rhs.data->number); + } + else + { + // if rhs is bigger than lhs then the answer is lhs + *(temp->number) = lhs; + } + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 operator% ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_1 bigint; + bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); + + uint16 remainder; + + lhs.short_div(lhs.data,rhs,temp,remainder); + temp->digits_used = 1; + *(temp->number) = remainder; + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator%= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + + data->digits_used = 1; + *(data->number) = remainder; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + const bigint_kernel_1& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + uint16 lhs, + const bigint_kernel_1& rhs + ) + { + return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator= ( + uint16 rhs + ) + { + // check if there are other references to our data + if (data->references != 1) + { + data->references -= 1; + try { + data = new data_record(slack); + } catch (...) { data->references += 1; throw; } + } + else + { + data->digits_used = 1; + } + + *(data->number) = rhs; + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator++ ( + ) + { + // if there are other references to this data then make a copy of it + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + increment(data,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + increment(data,temp); + delete data; + data = temp; + } + else + { + increment(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator++ ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + increment(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_1& bigint_kernel_1:: + operator-- ( + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + decrement(data,temp); + data = temp; + } + else + { + decrement(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_1 bigint_kernel_1:: + operator-- ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + decrement(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_1(temp,0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp = value; + temp <<= 16; + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used; // one past the end of number + uint16* r = result->number; + + while (number != end) + { + // add *number and the current carry + temp = *number + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // store the carry in the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used - 1; + uint16* r = result->number; + + uint32 temp = *number - value; + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + + while (number != end) + { + ++number; + ++r; + + // subtract the carry from *number + temp = *number - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + } + + // if we lost a digit in the subtraction + if (*r == 0) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + uint32 temp = 0; + + + const uint16* number = data->number; + uint16* r = result->number; + const uint16* end = r + data->digits_used; + + + + while ( r != end) + { + + // multiply *data and value and add in the carry + temp = *number*(uint32)value + (temp>>16); + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // put the final carry into the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& rem + ) const + { + + uint16 remainder = 0; + uint32 temp; + + + + const uint16* number = data->number + data->digits_used - 1; + const uint16* end = number - data->digits_used; + uint16* r = result->number + data->digits_used - 1; + + + // if we are losing a digit in this division + if (*number < value) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + // perform the actual division + while (number != end) + { + + temp = *number + (((uint32)remainder)<<16); + + *r = static_cast(temp/value); + remainder = static_cast(temp%value); + + --number; + --r; + } + + rem = remainder; + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp=0; + + uint16* min_num; // the number with the least digits used + uint16* max_num; // the number with the most digits used + uint16* min_end; // one past the end of min_num + uint16* max_end; // one past the end of max_num + uint16* r = result->number; + + uint32 max_digits_used; + if (lhs->digits_used < rhs->digits_used) + { + max_digits_used = rhs->digits_used; + min_num = lhs->number; + max_num = rhs->number; + min_end = min_num + lhs->digits_used; + max_end = max_num + rhs->digits_used; + } + else + { + max_digits_used = lhs->digits_used; + min_num = rhs->number; + max_num = lhs->number; + min_end = min_num + rhs->digits_used; + max_end = max_num + lhs->digits_used; + } + + + + + while (min_num != min_end) + { + // add *min_num, *max_num and the current carry + temp = *min_num + *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++min_num; + ++max_num; + ++r; + } + + + while (max_num != max_end) + { + // add *max_num and the current carry + temp = *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++max_num; + ++r; + } + + // check if there was a final carry + if ((temp>>16) != 0) + { + result->digits_used = max_digits_used + 1; + // put the carry into the most significant digit in the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = max_digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + + + const uint16* number1 = lhs->number; + const uint16* number2 = rhs->number; + const uint16* end = number2 + rhs->digits_used; + uint16* r = result->number; + + + + uint32 temp =0; + + + while (number2 != end) + { + + // subtract *number2 from *number1 and then subtract any carry + temp = *number1 - *number2 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++number2; + ++r; + } + + end = lhs->number + lhs->digits_used; + while (number1 != end) + { + + // subtract the carry from *number1 + temp = *number1 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++r; + } + + result->digits_used = lhs->digits_used; + // adjust the number of digits used appropriately + --r; + while (*r == 0 && result->digits_used > 1) + { + --r; + --result->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const + { + // zero result + result->digits_used = 1; + *(result->number) = 0; + + uint16* a; + uint16* b; + uint16* end; + + // copy lhs into remainder + remainder->digits_used = lhs->digits_used; + a = remainder->number; + end = a + remainder->digits_used; + b = lhs->number; + while (a != end) + { + *a = *b; + ++a; + ++b; + } + + + // if rhs is bigger than lhs then result == 0 and remainder == lhs + // so then we can quit right now + if (is_less_than(lhs,rhs)) + { + return; + } + + + // make a temporary number + data_record temp(lhs->digits_used + slack); + + + // shift rhs left until it is one shift away from being larger than lhs and + // put the number of left shifts necessary into shifts + uint32 shifts; + shifts = (lhs->digits_used - rhs->digits_used) * 16; + + shift_left(rhs,&temp,shifts); + + + // while (lhs > temp) + while (is_less_than(&temp,lhs)) + { + shift_left(&temp,&temp,1); + ++shifts; + } + // make sure lhs isn't smaller than temp + while (is_less_than(lhs,&temp)) + { + shift_right(&temp,&temp); + --shifts; + } + + + + // we want to execute the loop shifts +1 times + ++shifts; + while (shifts != 0) + { + shift_left(result,result,1); + // if (temp <= remainder) + if (!is_less_than(remainder,&temp)) + { + long_sub(remainder,&temp,remainder); + + // increment result + uint16* r = result->number; + uint16* end = r + result->digits_used; + while (true) + { + ++(*r); + // if there was no carry then we are done + if (*r != 0) + break; + + ++r; + + // if we hit the end of r and there is still a carry then + // the next digit of r is 1 and there is one more digit used + if (r == end) + { + *r = 1; + ++(result->digits_used); + break; + } + } + } + shift_right(&temp,&temp); + --shifts; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + + + const data_record* aa; + const data_record* bb; + + if (lhs->digits_used < rhs->digits_used) + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = lhs; + bb = rhs; + } + else + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = rhs; + bb = lhs; + } + // this is where we actually copy lhs and rhs + data_record b(*bb,aa->digits_used+slack); // the larger(approximately) of lhs and rhs + + + uint32 shift_value = 0; + uint16* anum = aa->number; + uint16* end = anum + aa->digits_used; + while (anum != end ) + { + uint16 bit = 0x0001; + + for (int i = 0; i < 16; ++i) + { + // if the specified bit of a is 1 + if ((*anum & bit) != 0) + { + shift_left(&b,&b,shift_value); + shift_value = 0; + long_add(&b,result,result); + } + ++shift_value; + bit <<= 1; + } + + ++anum; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const + { + uint32 offset = shift_amount/16; + shift_amount &= 0xf; // same as shift_amount %= 16; + + uint16* r = result->number + data->digits_used + offset; // result + uint16* end = data->number; + uint16* s = end + data->digits_used; // source + const uint32 temp = 16 - shift_amount; + + *r = (*(--s) >> temp); + // set the number of digits used in the result + // if the upper bits from *s were zero then don't count this first word + if (*r == 0) + { + result->digits_used = data->digits_used + offset; + } + else + { + result->digits_used = data->digits_used + offset + 1; + } + --r; + + while (s != end) + { + *r = ((*s << shift_amount) | ( *(s-1) >> temp)); + --r; + --s; + } + *r = *s << shift_amount; + + // now zero the rest of the result + end = result->number; + while (r != end) + *(--r) = 0; + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + shift_right ( + const data_record* data, + data_record* result + ) const + { + + uint16* r = result->number; // result + uint16* s = data->number; // source + uint16* end = s + data->digits_used - 1; + + while (s != end) + { + *r = (*s >> 1) | (*(s+1) << 15); + ++r; + ++s; + } + *r = *s >> 1; + + + // calculate the new number for digits_used + if (*r == 0) + { + if (data->digits_used != 1) + result->digits_used = data->digits_used - 1; + else + result->digits_used = 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const + { + uint32 lhs_digits_used = lhs->digits_used; + uint32 rhs_digits_used = rhs->digits_used; + + // if lhs is definitely less than rhs + if (lhs_digits_used < rhs_digits_used ) + return true; + // if lhs is definitely greater than rhs + else if (lhs_digits_used > rhs_digits_used) + return false; + else + { + uint16* end = lhs->number; + uint16* l = end + lhs_digits_used; + uint16* r = rhs->number + rhs_digits_used; + + while (l != end) + { + --l; + --r; + if (*l < *r) + return true; + else if (*l > *r) + return false; + } + + // at this point we know that they are equal + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_1:: + is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const + { + // if lhs and rhs are definitely not equal + if (lhs->digits_used != rhs->digits_used ) + { + return false; + } + else + { + uint16* l = lhs->number; + uint16* r = rhs->number; + uint16* end = l + lhs->digits_used; + + while (l != end) + { + if (*l != *r) + return false; + ++l; + ++r; + } + + // at this point we know that they are equal + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + increment ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s + 1; + + // if there was no carry then break out of the loop + if (*d != 0) + { + dest->digits_used = source->digits_used; + + // copy the rest of the digits over to d + ++d; ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + + break; + } + + + ++s; + + // if we have hit the end of s and there was a carry up to this point + // then just make the next digit 1 and add one to the digits used + if (s == end) + { + ++d; + dest->digits_used = source->digits_used + 1; + *d = 1; + break; + } + + ++d; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_1:: + decrement ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s - 1; + + // if there was no carry then break out of the loop + if (*d != 0xFFFF) + { + // if we lost a digit in the subtraction + if (*d == 0 && s+1 == end) + { + if (source->digits_used == 1) + dest->digits_used = 1; + else + dest->digits_used = source->digits_used - 1; + } + else + { + dest->digits_used = source->digits_used; + } + break; + } + else + { + ++d; + ++s; + } + + } + + // copy the rest of the digits over to d + ++d; + ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIGINT_KERNEL_1_CPp_ + diff --git a/dlib/bigint/bigint_kernel_1.h b/dlib/bigint/bigint_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..32463c3b26f017d9dabdada92290945251fe2434 --- /dev/null +++ b/dlib/bigint/bigint_kernel_1.h @@ -0,0 +1,543 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_1_ +#define DLIB_BIGINT_KERNEl_1_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" +#include + +namespace dlib +{ + + + class bigint_kernel_1 + { + /*! + INITIAL VALUE + slack == 25 + data->number[0] == 0 + data->size == slack + data->references == 1 + data->digits_used == 1 + + + CONVENTION + slack == the number of extra digits placed into the number when it is + created. the slack value should never be less than 1 + + data->number == pointer to an array of data->size uint16s. + data represents a string of base 65535 numbers with data[0] being + the least significant bit and data[data->digits_used-1] being the most + significant + + + NOTE: In the comments I will consider a word to be a 16 bit value + + + data->digits_used == the number of significant digits in the number. + data->digits_used tells us the number of used elements in the + data->number array so everything beyond data->number[data->digits_used-1] + is undefined + + data->references == the number of bigint_kernel_1 objects which refer + to this data_record + + + + !*/ + + + struct data_record + { + + + explicit data_record( + uint32 size_ + ) : + size(size_), + number(new uint16[size_]), + references(1), + digits_used(1) + {*number = 0;} + /*! + ensures + - initializes *this to represent zero + !*/ + + data_record( + const data_record& item, + uint32 additional_size + ) : + size(item.digits_used + additional_size), + number(new uint16[size]), + references(1), + digits_used(item.digits_used) + { + uint16* source = item.number; + uint16* dest = number; + uint16* end = source + digits_used; + while (source != end) + { + *dest = *source; + ++dest; + ++source; + } + } + /*! + ensures + - *this is a copy of item except with + size == item.digits_used + additional_size + !*/ + + ~data_record( + ) + { + delete [] number; + } + + + const uint32 size; + uint16* number; + uint32 references; + uint32 digits_used; + + private: + // no copy constructor + data_record ( data_record&); + }; + + + + // note that the second parameter is just there + // to resolve the ambiguity between this constructor and + // bigint_kernel_1(uint32) + explicit bigint_kernel_1 ( + data_record* data_, int + ): slack(25),data(data_) {} + /*! + ensures + - *this is initialized with data_ as its data member + !*/ + + + public: + + bigint_kernel_1 ( + ); + + bigint_kernel_1 ( + uint32 value + ); + + bigint_kernel_1 ( + const bigint_kernel_1& item + ); + + virtual ~bigint_kernel_1 ( + ); + + const bigint_kernel_1 operator+ ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator+= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator- ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator-= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator* ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator*= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator/ ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator/= ( + const bigint_kernel_1& rhs + ); + + const bigint_kernel_1 operator% ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator%= ( + const bigint_kernel_1& rhs + ); + + bool operator < ( + const bigint_kernel_1& rhs + ) const; + + bool operator == ( + const bigint_kernel_1& rhs + ) const; + + bigint_kernel_1& operator= ( + const bigint_kernel_1& rhs + ); + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_1& rhs + ); + + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_1& rhs + ); + + bigint_kernel_1& operator++ ( + ); + + const bigint_kernel_1 operator++ ( + int + ); + + bigint_kernel_1& operator-- ( + ); + + const bigint_kernel_1 operator-- ( + int + ); + + friend const bigint_kernel_1 operator+ ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator+ ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator+= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator- ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator- ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator-= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator* ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator* ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator*= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator/ ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator/ ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator/= ( + uint16 rhs + ); + + friend const bigint_kernel_1 operator% ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend const bigint_kernel_1 operator% ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + bigint_kernel_1& operator%= ( + uint16 rhs + ); + + friend bool operator < ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + friend bool operator < ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + friend bool operator == ( + const bigint_kernel_1& lhs, + uint16 rhs + ); + + friend bool operator == ( + uint16 lhs, + const bigint_kernel_1& rhs + ); + + bigint_kernel_1& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_1& item + ) { data_record* temp = data; data = item.data; item.data = temp; } + + + private: + + void long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 + ensures + - result == lhs + rhs + !*/ + + void long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - lhs >= rhs + - result->size >= lhs->digits_used + ensures + - result == lhs - rhs + !*/ + + void long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const; + /*! + requires + - rhs != 0 + - result->size >= lhs->digits_used + - remainder->size >= lhs->digits_used + - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) + ensures + - result == lhs / rhs + - remainder == lhs % rhs + !*/ + + void long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= lhs->digits_used + rhs->digits_used + - result != lhs + - result != rhs + ensures + - result == lhs * rhs + !*/ + + void short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->size + 1 + ensures + - result == data + value + !*/ + + void short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - data >= value + - result->size >= data->digits_used + ensures + - result == data - value + !*/ + + void short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + 1 + ensures + - result == data * value + !*/ + + void short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& remainder + ) const; + /*! + requires + - value != 0 + - result->size >= data->digits_used + ensures + - result = data*value + - remainder = data%value + !*/ + + void shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const; + /*! + requires + - result->size >= data->digits_used + shift_amount/8 + 1 + ensures + - result == data << shift_amount + !*/ + + void shift_right ( + const data_record* data, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + ensures + - result == data >> 1 + !*/ + + bool is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs < rhs + - returns false otherwise + !*/ + + bool is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs == rhs + - returns false otherwise + !*/ + + void increment ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + - dest->size >= source->digits_used + 1 + ensures + - dest = source + 1 + !*/ + + void decrement ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + source != 0 + ensuers + dest = source - 1 + !*/ + + // member data + const uint32 slack; + data_record* data; + + + + }; + + inline void swap ( + bigint_kernel_1& a, + bigint_kernel_1& b + ) { a.swap(b); } + + inline void serialize ( + const bigint_kernel_1& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + inline void deserialize ( + bigint_kernel_1& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in >> item; + in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + + inline bool operator> (const bigint_kernel_1& a, const bigint_kernel_1& b) { return b < a; } + inline bool operator!= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a == b); } + inline bool operator<= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(b < a); } + inline bool operator>= (const bigint_kernel_1& a, const bigint_kernel_1& b) { return !(a < b); } +} + +#ifdef NO_MAKEFILE +#include "bigint_kernel_1.cpp" +#endif + +#endif // DLIB_BIGINT_KERNEl_1_ + diff --git a/dlib/bigint/bigint_kernel_2.cpp b/dlib/bigint/bigint_kernel_2.cpp new file mode 100644 index 0000000000000000000000000000000000000000..694b1ae1f752146f8539bb5f9e0735e2c5e701d3 --- /dev/null +++ b/dlib/bigint/bigint_kernel_2.cpp @@ -0,0 +1,1945 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEL_2_CPp_ +#define DLIB_BIGINT_KERNEL_2_CPp_ +#include "bigint_kernel_2.h" + +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member/friend function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + ) : + slack(25), + data(new data_record(slack)) + {} + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + uint32 value + ) : + slack(25), + data(new data_record(slack)) + { + *(data->number) = static_cast(value&0xFFFF); + *(data->number+1) = static_cast((value>>16)&0xFFFF); + if (*(data->number+1) != 0) + data->digits_used = 2; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + bigint_kernel_2 ( + const bigint_kernel_2& item + ) : + slack(25), + data(item.data) + { + data->references += 1; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2:: + ~bigint_kernel_2 ( + ) + { + if (data->references == 1) + { + delete data; + } + else + { + data->references -= 1; + } + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator+ ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + std::max(rhs.data->digits_used,data->digits_used) + slack + ); + long_add(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator+= ( + const bigint_kernel_2& rhs + ) + { + // if there are other references to our data + if (data->references != 1) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + data->references -= 1; + long_add(data,rhs.data,temp); + data = temp; + } + // if data is not big enough for the result + else if (data->size <= std::max(data->digits_used,rhs.data->digits_used)) + { + data_record* temp = new data_record(std::max(data->digits_used,rhs.data->digits_used)+slack); + long_add(data,rhs.data,temp); + delete data; + data = temp; + } + // there is enough size and no references + else + { + long_add(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator- ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + slack + ); + long_sub(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-= ( + const bigint_kernel_2& rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + long_sub(data,rhs.data,temp); + data = temp; + } + else + { + long_sub(data,rhs.data,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator* ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record ( + data->digits_used + rhs.data->digits_used + slack + ); + long_mul(data,rhs.data,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator*= ( + const bigint_kernel_2& rhs + ) + { + // create a data_record to store the result of the multiplication in + data_record* temp = new data_record(rhs.data->digits_used+data->digits_used+slack); + long_mul(data,rhs.data,temp); + + // if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + else + { + delete data; + } + data = temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator/ ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete remainder; + + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator/= ( + const bigint_kernel_2& rhs + ) + { + + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = temp; + delete remainder; + + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator% ( + const bigint_kernel_2& rhs + ) const + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + delete temp; + return bigint_kernel_2(remainder,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator%= ( + const bigint_kernel_2& rhs + ) + { + data_record* temp = new data_record(data->digits_used+slack); + data_record* remainder; + try { + remainder = new data_record(data->digits_used+slack); + } catch (...) { delete temp; throw; } + + long_div(data,rhs.data,temp,remainder); + + // check if there are other references to data + if (data->references != 1) + { + data->references -= 1; + } + // if there are no references to data then it must be deleted + else + { + delete data; + } + data = remainder; + delete temp; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + operator < ( + const bigint_kernel_2& rhs + ) const + { + return is_less_than(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + operator == ( + const bigint_kernel_2& rhs + ) const + { + return is_equal_to(data,rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator= ( + const bigint_kernel_2& rhs + ) + { + if (this == &rhs) + return *this; + + // if we have the only reference to our data then delete it + if (data->references == 1) + { + delete data; + data = rhs.data; + data->references += 1; + } + else + { + data->references -= 1; + data = rhs.data; + data->references += 1; + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + std::ostream& operator<< ( + std::ostream& out_, + const bigint_kernel_2& rhs + ) + { + std::ostream out(out_.rdbuf()); + + typedef bigint_kernel_2 bigint; + + bigint::data_record* temp = new bigint::data_record(*rhs.data,0); + + + + // get a char array big enough to hold the number in ascii format + char* str; + try { + str = new char[(rhs.data->digits_used)*5+10]; + } catch (...) { delete temp; throw; } + + char* str_start = str; + str += (rhs.data->digits_used)*5+9; + *str = 0; --str; + + + uint16 remainder; + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + + + // keep looping until temp represents zero + while (temp->digits_used != 1 || *(temp->number) != 0) + { + rhs.short_div(temp,10000,temp,remainder); + + // pull the digits out of remainder + char a = remainder % 10 + '0'; + remainder /= 10; + char b = remainder % 10 + '0'; + remainder /= 10; + char c = remainder % 10 + '0'; + remainder /= 10; + char d = remainder % 10 + '0'; + remainder /= 10; + + *str = a; --str; + *str = b; --str; + *str = c; --str; + *str = d; --str; + } + + // throw away and extra leading zeros + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + if (*str == '0') + ++str; + + + + + out << str; + delete [] str_start; + delete temp; + return out_; + + } + +// ---------------------------------------------------------------------------------------- + + std::istream& operator>> ( + std::istream& in_, + bigint_kernel_2& rhs + ) + { + std::istream in(in_.rdbuf()); + + // ignore any leading whitespaces + while (in.peek() == ' ' || in.peek() == '\t' || in.peek() == '\n') + { + in.get(); + } + + // if the first digit is not an integer then this is an error + if ( !(in.peek() >= '0' && in.peek() <= '9')) + { + in_.clear(std::ios::failbit); + return in_; + } + + int num_read; + bigint_kernel_2 temp; + do + { + + // try to get 4 chars from in + num_read = 1; + char a = 0; + char b = 0; + char c = 0; + char d = 0; + + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + a = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + b = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + c = in.get(); + } + if (in.peek() >= '0' && in.peek() <= '9') + { + num_read *= 10; + d = in.get(); + } + + // merge the for digits into an uint16 + uint16 num = 0; + if (a != 0) + { + num = a - '0'; + } + if (b != 0) + { + num *= 10; + num += b - '0'; + } + if (c != 0) + { + num *= 10; + num += c - '0'; + } + if (d != 0) + { + num *= 10; + num += d - '0'; + } + + + if (num_read != 1) + { + // shift the digits in temp left by the number of new digits we just read + temp *= num_read; + // add in new digits + temp += num; + } + + } while (num_read == 10000); + + + rhs = temp; + return in_; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator+ ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_add(rhs.data,lhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator+ ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_add(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator+= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_add(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_add(data,rhs,temp); + delete data; + data = temp; + } + // or if there is plenty of space and no references + else + { + short_add(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator- ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + *(temp->number) = lhs - *(rhs.data->number); + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator- ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_sub(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_sub(data,rhs,temp); + data = temp; + } + else + { + short_sub(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator* ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (rhs.data->digits_used+rhs.slack); + + rhs.short_mul(rhs.data,lhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator* ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + lhs.short_mul(lhs.data,rhs,temp); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator*= ( + uint16 rhs + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_mul(data,rhs,temp); + data = temp; + } + // or if we need to enlarge data + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + short_mul(data,rhs,temp); + delete data; + data = temp; + } + else + { + short_mul(data,rhs,data); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator/ ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + // if rhs might not be bigger than lhs + if (rhs.data->digits_used == 1) + { + *(temp->number) = lhs/ *(rhs.data->number); + } + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator/ ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record + (lhs.data->digits_used+lhs.slack); + + uint16 remainder; + lhs.short_div(lhs.data,rhs,temp,remainder); + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator/= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator% ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + typedef bigint_kernel_2 bigint; + // temp is zero by default + bigint::data_record* temp = new bigint::data_record(rhs.slack); + + if (rhs.data->digits_used == 1) + { + // if rhs is just an uint16 inside then perform the modulus + *(temp->number) = lhs % *(rhs.data->number); + } + else + { + // if rhs is bigger than lhs then the answer is lhs + *(temp->number) = lhs; + } + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 operator% ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + typedef bigint_kernel_2 bigint; + bigint::data_record* temp = new bigint::data_record(lhs.data->digits_used+lhs.slack); + + uint16 remainder; + + lhs.short_div(lhs.data,rhs,temp,remainder); + temp->digits_used = 1; + *(temp->number) = remainder; + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator%= ( + uint16 rhs + ) + { + uint16 remainder; + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + short_div(data,rhs,temp,remainder); + data = temp; + } + else + { + short_div(data,rhs,data,remainder); + } + + data->digits_used = 1; + *(data->number) = remainder; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + return (rhs.data->digits_used > 1 || lhs < *(rhs.data->number) ); + } + +// ---------------------------------------------------------------------------------------- + + bool operator < ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) < rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + const bigint_kernel_2& lhs, + uint16 rhs + ) + { + return (lhs.data->digits_used == 1 && *(lhs.data->number) == rhs); + } + +// ---------------------------------------------------------------------------------------- + + bool operator == ( + uint16 lhs, + const bigint_kernel_2& rhs + ) + { + return (rhs.data->digits_used == 1 && *(rhs.data->number) == lhs); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator= ( + uint16 rhs + ) + { + // check if there are other references to our data + if (data->references != 1) + { + data->references -= 1; + try { + data = new data_record(slack); + } catch (...) { data->references += 1; throw; } + } + else + { + data->digits_used = 1; + } + + *(data->number) = rhs; + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator++ ( + ) + { + // if there are other references to this data then make a copy of it + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + increment(data,temp); + data = temp; + } + // or if we need to enlarge data then do so + else if (data->digits_used == data->size) + { + data_record* temp = new data_record(data->digits_used+slack); + increment(data,temp); + delete data; + data = temp; + } + else + { + increment(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator++ ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + increment(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- + + bigint_kernel_2& bigint_kernel_2:: + operator-- ( + ) + { + // if there are other references to this data + if (data->references != 1) + { + data_record* temp = new data_record(data->digits_used+slack); + data->references -= 1; + decrement(data,temp); + data = temp; + } + else + { + decrement(data,data); + } + + return *this; + } + +// ---------------------------------------------------------------------------------------- + + const bigint_kernel_2 bigint_kernel_2:: + operator-- ( + int + ) + { + data_record* temp; // this is the copy of temp we will return in the end + + data_record* temp2 = new data_record(data->digits_used+slack); + decrement(data,temp2); + + temp = data; + data = temp2; + + return bigint_kernel_2(temp,0); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp = value; + temp <<= 16; + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used; // one past the end of number + uint16* r = result->number; + + while (number != end) + { + // add *number and the current carry + temp = *number + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // store the carry in the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + + const uint16* number = data->number; + const uint16* end = number + data->digits_used - 1; + uint16* r = result->number; + + uint32 temp = *number - value; + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + + while (number != end) + { + ++number; + ++r; + + // subtract the carry from *number + temp = *number - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + } + + // if we lost a digit in the subtraction + if (*r == 0) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const + { + + uint32 temp = 0; + + + const uint16* number = data->number; + uint16* r = result->number; + const uint16* end = r + data->digits_used; + + + + while ( r != end) + { + + // multiply *data and value and add in the carry + temp = *number*(uint32)value + (temp>>16); + + // put the low word of temp into *data + *r = static_cast(temp & 0xFFFF); + + ++number; + ++r; + } + + // if there is a final carry + if ((temp>>16) != 0) + { + result->digits_used = data->digits_used + 1; + // put the final carry into the most significant digit of the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& rem + ) const + { + + uint16 remainder = 0; + uint32 temp; + + + + const uint16* number = data->number + data->digits_used - 1; + const uint16* end = number - data->digits_used; + uint16* r = result->number + data->digits_used - 1; + + + // if we are losing a digit in this division + if (*number < value) + { + if (data->digits_used == 1) + result->digits_used = 1; + else + result->digits_used = data->digits_used - 1; + } + else + { + result->digits_used = data->digits_used; + } + + + // perform the actual division + while (number != end) + { + + temp = *number + (((uint32)remainder)<<16); + + *r = static_cast(temp/value); + remainder = static_cast(temp%value); + + --number; + --r; + } + + rem = remainder; + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // put value into the carry part of temp + uint32 temp=0; + + uint16* min_num; // the number with the least digits used + uint16* max_num; // the number with the most digits used + uint16* min_end; // one past the end of min_num + uint16* max_end; // one past the end of max_num + uint16* r = result->number; + + uint32 max_digits_used; + if (lhs->digits_used < rhs->digits_used) + { + max_digits_used = rhs->digits_used; + min_num = lhs->number; + max_num = rhs->number; + min_end = min_num + lhs->digits_used; + max_end = max_num + rhs->digits_used; + } + else + { + max_digits_used = lhs->digits_used; + min_num = rhs->number; + max_num = lhs->number; + min_end = min_num + rhs->digits_used; + max_end = max_num + lhs->digits_used; + } + + + + + while (min_num != min_end) + { + // add *min_num, *max_num and the current carry + temp = *min_num + *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++min_num; + ++max_num; + ++r; + } + + + while (max_num != max_end) + { + // add *max_num and the current carry + temp = *max_num + (temp>>16); + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++max_num; + ++r; + } + + // check if there was a final carry + if ((temp>>16) != 0) + { + result->digits_used = max_digits_used + 1; + // put the carry into the most significant digit in the result + *r = static_cast(temp>>16); + } + else + { + result->digits_used = max_digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + + + const uint16* number1 = lhs->number; + const uint16* number2 = rhs->number; + const uint16* end = number2 + rhs->digits_used; + uint16* r = result->number; + + + + uint32 temp =0; + + + while (number2 != end) + { + + // subtract *number2 from *number1 and then subtract any carry + temp = *number1 - *number2 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++number2; + ++r; + } + + end = lhs->number + lhs->digits_used; + while (number1 != end) + { + + // subtract the carry from *number1 + temp = *number1 - (temp>>31); + + // put the low word of temp into *r + *r = static_cast(temp & 0xFFFF); + + ++number1; + ++r; + } + + result->digits_used = lhs->digits_used; + // adjust the number of digits used appropriately + --r; + while (*r == 0 && result->digits_used > 1) + { + --r; + --result->digits_used; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const + { + // zero result + result->digits_used = 1; + *(result->number) = 0; + + uint16* a; + uint16* b; + uint16* end; + + // copy lhs into remainder + remainder->digits_used = lhs->digits_used; + a = remainder->number; + end = a + remainder->digits_used; + b = lhs->number; + while (a != end) + { + *a = *b; + ++a; + ++b; + } + + + // if rhs is bigger than lhs then result == 0 and remainder == lhs + // so then we can quit right now + if (is_less_than(lhs,rhs)) + { + return; + } + + + // make a temporary number + data_record temp(lhs->digits_used + slack); + + + // shift rhs left until it is one shift away from being larger than lhs and + // put the number of left shifts necessary into shifts + uint32 shifts; + shifts = (lhs->digits_used - rhs->digits_used) * 16; + + shift_left(rhs,&temp,shifts); + + + // while (lhs > temp) + while (is_less_than(&temp,lhs)) + { + shift_left(&temp,&temp,1); + ++shifts; + } + // make sure lhs isn't smaller than temp + while (is_less_than(lhs,&temp)) + { + shift_right(&temp,&temp); + --shifts; + } + + + + // we want to execute the loop shifts +1 times + ++shifts; + while (shifts != 0) + { + shift_left(result,result,1); + // if (temp <= remainder) + if (!is_less_than(remainder,&temp)) + { + long_sub(remainder,&temp,remainder); + + // increment result + uint16* r = result->number; + uint16* end = r + result->digits_used; + while (true) + { + ++(*r); + // if there was no carry then we are done + if (*r != 0) + break; + + ++r; + + // if we hit the end of r and there is still a carry then + // the next digit of r is 1 and there is one more digit used + if (r == end) + { + *r = 1; + ++(result->digits_used); + break; + } + } + } + shift_right(&temp,&temp); + --shifts; + } + + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const + { + // if one of the numbers is small then use this simple but O(n^2) algorithm + if (std::min(lhs->digits_used, rhs->digits_used) < 10) + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + + + const data_record* aa; + const data_record* bb; + + if (lhs->digits_used < rhs->digits_used) + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = lhs; + bb = rhs; + } + else + { + // make copies of lhs and rhs and give them an appropriate amount of + // extra memory so there won't be any overflows + aa = rhs; + bb = lhs; + } + + // copy the larger(approximately) of lhs and rhs into b + data_record b(*bb,aa->digits_used+slack); + + + uint32 shift_value = 0; + uint16* anum = aa->number; + uint16* end = anum + aa->digits_used; + while (anum != end ) + { + uint16 bit = 0x0001; + + for (int i = 0; i < 16; ++i) + { + // if the specified bit of a is 1 + if ((*anum & bit) != 0) + { + shift_left(&b,&b,shift_value); + shift_value = 0; + long_add(&b,result,result); + } + ++shift_value; + bit <<= 1; + } + + ++anum; + } + } + else // else if both lhs and rhs are large then use the more complex + // O(n*logn) algorithm + { + uint32 size = 1; + // make size a power of 2 + while (size < (lhs->digits_used + rhs->digits_used)*2) + { + size *= 2; + } + + // allocate some temporary space so we can do the FFT + ct* a = new ct[size]; + ct* b; try {b = new ct[size]; } catch (...) { delete [] a; throw; } + + // load lhs into the a array. We are breaking the input number into + // 8bit chunks for the purpose of using this fft algorithm. The reason + // for this is so that we have smaller numbers coming out of the final + // ifft. This helps avoid overflow. + for (uint32 i = 0; i < lhs->digits_used; ++i) + { + a[i*2] = ct((t)(lhs->number[i]&0xFF),0); + a[i*2+1] = ct((t)(lhs->number[i]>>8),0); + } + for (uint32 i = lhs->digits_used*2; i < size; ++i) + { + a[i] = 0; + } + + // load rhs into the b array + for (uint32 i = 0; i < rhs->digits_used; ++i) + { + b[i*2] = ct((t)(rhs->number[i]&0xFF),0); + b[i*2+1] = ct((t)(rhs->number[i]>>8),0); + } + for (uint32 i = rhs->digits_used*2; i < size; ++i) + { + b[i] = 0; + } + + // perform the forward fft of a and b + fft(a,size); + fft(b,size); + + const double l = 1.0/size; + + // do the pointwise multiply of a and b and also apply the scale + // factor in this loop too. + for (unsigned long i = 0; i < size; ++i) + { + a[i] = l*a[i]*b[i]; + } + + // Now compute the inverse fft of the pointwise multiplication of a and b. + // This is basically the result. We just have to take care of any carries + // that should happen. + ifft(a,size); + + // loop over the result and propagate any carries that need to take place. + // We will also be moving the resulting numbers into result->number at + // the same time. + uint64 carry = 0; + result->digits_used = 0; + int zeros = 0; + const uint32 len = lhs->digits_used + rhs->digits_used; + for (unsigned long i = 0; i < len; ++i) + { + uint64 num1 = static_cast(std::round(a[i*2].real())); + num1 += carry; + carry = 0; + if (num1 > 255) + { + carry = num1 >> 8; + num1 = (num1&0xFF); + } + + uint64 num2 = static_cast(std::round(a[i*2+1].real())); + num2 += carry; + carry = 0; + if (num2 > 255) + { + carry = num2 >> 8; + num2 = (num2&0xFF); + } + + // put the new number into its final place + num1 = (num2<<8) | num1; + result->number[i] = static_cast(num1); + + // keep track of the number of leading zeros + if (num1 == 0) + ++zeros; + else + zeros = 0; + ++(result->digits_used); + } + + // adjust digits_used so that it reflects the actual number + // of non-zero digits in our representation. + result->digits_used -= zeros; + + // if the result was zero then adjust the result accordingly + if (result->digits_used == 0) + { + // make result be zero + result->digits_used = 1; + *(result->number) = 0; + } + + // free all the temporary buffers + delete [] a; + delete [] b; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const + { + uint32 offset = shift_amount/16; + shift_amount &= 0xf; // same as shift_amount %= 16; + + uint16* r = result->number + data->digits_used + offset; // result + uint16* end = data->number; + uint16* s = end + data->digits_used; // source + const uint32 temp = 16 - shift_amount; + + *r = (*(--s) >> temp); + // set the number of digits used in the result + // if the upper bits from *s were zero then don't count this first word + if (*r == 0) + { + result->digits_used = data->digits_used + offset; + } + else + { + result->digits_used = data->digits_used + offset + 1; + } + --r; + + while (s != end) + { + *r = ((*s << shift_amount) | ( *(s-1) >> temp)); + --r; + --s; + } + *r = *s << shift_amount; + + // now zero the rest of the result + end = result->number; + while (r != end) + *(--r) = 0; + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + shift_right ( + const data_record* data, + data_record* result + ) const + { + + uint16* r = result->number; // result + uint16* s = data->number; // source + uint16* end = s + data->digits_used - 1; + + while (s != end) + { + *r = (*s >> 1) | (*(s+1) << 15); + ++r; + ++s; + } + *r = *s >> 1; + + + // calculate the new number for digits_used + if (*r == 0) + { + if (data->digits_used != 1) + result->digits_used = data->digits_used - 1; + else + result->digits_used = 1; + } + else + { + result->digits_used = data->digits_used; + } + + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const + { + uint32 lhs_digits_used = lhs->digits_used; + uint32 rhs_digits_used = rhs->digits_used; + + // if lhs is definitely less than rhs + if (lhs_digits_used < rhs_digits_used ) + return true; + // if lhs is definitely greater than rhs + else if (lhs_digits_used > rhs_digits_used) + return false; + else + { + uint16* end = lhs->number; + uint16* l = end + lhs_digits_used; + uint16* r = rhs->number + rhs_digits_used; + + while (l != end) + { + --l; + --r; + if (*l < *r) + return true; + else if (*l > *r) + return false; + } + + // at this point we know that they are equal + return false; + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bigint_kernel_2:: + is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const + { + // if lhs and rhs are definitely not equal + if (lhs->digits_used != rhs->digits_used ) + { + return false; + } + else + { + uint16* l = lhs->number; + uint16* r = rhs->number; + uint16* end = l + lhs->digits_used; + + while (l != end) + { + if (*l != *r) + return false; + ++l; + ++r; + } + + // at this point we know that they are equal + return true; + } + + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + increment ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s + 1; + + // if there was no carry then break out of the loop + if (*d != 0) + { + dest->digits_used = source->digits_used; + + // copy the rest of the digits over to d + ++d; ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + + break; + } + + + ++s; + + // if we have hit the end of s and there was a carry up to this point + // then just make the next digit 1 and add one to the digits used + if (s == end) + { + ++d; + dest->digits_used = source->digits_used + 1; + *d = 1; + break; + } + + ++d; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + decrement ( + const data_record* source, + data_record* dest + ) const + { + uint16* s = source->number; + uint16* d = dest->number; + uint16* end = s + source->digits_used; + while (true) + { + *d = *s - 1; + + // if there was no carry then break out of the loop + if (*d != 0xFFFF) + { + // if we lost a digit in the subtraction + if (*d == 0 && s+1 == end) + { + if (source->digits_used == 1) + dest->digits_used = 1; + else + dest->digits_used = source->digits_used - 1; + } + else + { + dest->digits_used = source->digits_used; + } + break; + } + else + { + ++d; + ++s; + } + + } + + // copy the rest of the digits over to d + ++d; + ++s; + while (s != end) + { + *d = *s; + ++d; + ++s; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + fft ( + ct* data, + unsigned long len + ) const + { + const t pi2 = -2.0*3.1415926535897932384626433832795028841971693993751; + + const unsigned long half = len/2; + + std::vector twiddle_factors; + twiddle_factors.resize(half); + + // compute the complex root of unity w + const t temp = pi2/len; + ct w = ct(std::cos(temp),std::sin(temp)); + + ct w_pow = 1; + + // compute the twiddle factors + for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) + { + twiddle_factors[j] = w_pow; + w_pow *= w; + } + + ct a, b; + + // now compute the decimation in frequency. This first + // outer loop loops log2(len) number of times + unsigned long skip = 1; + for (unsigned long step = half; step != 0; step >>= 1) + { + // do blocks of butterflies in this loop + for (unsigned long j = 0; j < len; j += step*2) + { + // do step butterflies + for (unsigned long k = 0; k < step; ++k) + { + const unsigned long a_idx = j+k; + const unsigned long b_idx = j+k+step; + a = data[a_idx] + data[b_idx]; + b = (data[a_idx] - data[b_idx])*twiddle_factors[k*skip]; + data[a_idx] = a; + data[b_idx] = b; + } + } + skip *= 2; + } + } + +// ---------------------------------------------------------------------------------------- + + void bigint_kernel_2:: + ifft( + ct* data, + unsigned long len + ) const + { + const t pi2 = 2.0*3.1415926535897932384626433832795028841971693993751; + + const unsigned long half = len/2; + + std::vector twiddle_factors; + twiddle_factors.resize(half); + + // compute the complex root of unity w + const t temp = pi2/len; + ct w = ct(std::cos(temp),std::sin(temp)); + + ct w_pow = 1; + + // compute the twiddle factors + for (std::vector::size_type j = 0; j < twiddle_factors.size(); ++j) + { + twiddle_factors[j] = w_pow; + w_pow *= w; + } + + ct a, b; + + // now compute the inverse decimation in frequency. This first + // outer loop loops log2(len) number of times + unsigned long skip = half; + for (unsigned long step = 1; step <= half; step <<= 1) + { + // do blocks of butterflies in this loop + for (unsigned long j = 0; j < len; j += step*2) + { + // do step butterflies + for (unsigned long k = 0; k < step; ++k) + { + const unsigned long a_idx = j+k; + const unsigned long b_idx = j+k+step; + data[b_idx] *= twiddle_factors[k*skip]; + a = data[a_idx] + data[b_idx]; + b = data[a_idx] - data[b_idx]; + data[a_idx] = a; + data[b_idx] = b; + } + } + skip /= 2; + } + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIGINT_KERNEL_2_CPp_ + diff --git a/dlib/bigint/bigint_kernel_2.h b/dlib/bigint/bigint_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..4c771cf372896aac8113a914f8d4627c6a6f5d98 --- /dev/null +++ b/dlib/bigint/bigint_kernel_2.h @@ -0,0 +1,569 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_2_ +#define DLIB_BIGINT_KERNEl_2_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" +#include +#include +#include +#include + +namespace dlib +{ + + class bigint_kernel_2 + { + /*! + INITIAL VALUE + slack == 25 + data->number[0] == 0 + data->size == slack + data->references == 1 + data->digits_used == 1 + + + CONVENTION + slack == the number of extra digits placed into the number when it is + created. the slack value should never be less than 1 + + data->number == pointer to an array of data->size uint16s. + data represents a string of base 65535 numbers with data[0] being + the least significant bit and data[data->digits_used-1] being the most + significant + + + NOTE: In the comments I will consider a word to be a 16 bit value + + + data->digits_used == the number of significant digits in the number. + data->digits_used tells us the number of used elements in the + data->number array so everything beyond data->number[data->digits_used-1] + is undefined + + data->references == the number of bigint_kernel_2 objects which refer + to this data_record + !*/ + + + struct data_record + { + + + explicit data_record( + uint32 size_ + ) : + size(size_), + number(new uint16[size_]), + references(1), + digits_used(1) + {*number = 0;} + /*! + ensures + - initializes *this to represent zero + !*/ + + data_record( + const data_record& item, + uint32 additional_size + ) : + size(item.digits_used + additional_size), + number(new uint16[size]), + references(1), + digits_used(item.digits_used) + { + uint16* source = item.number; + uint16* dest = number; + uint16* end = source + digits_used; + while (source != end) + { + *dest = *source; + ++dest; + ++source; + } + } + /*! + ensures + - *this is a copy of item except with + size == item.digits_used + additional_size + !*/ + + ~data_record( + ) + { + delete [] number; + } + + + const uint32 size; + uint16* number; + uint32 references; + uint32 digits_used; + + private: + // no copy constructor + data_record ( data_record&); + }; + + + // note that the second parameter is just there + // to resolve the ambiguity between this constructor and + // bigint_kernel_2(uint32) + explicit bigint_kernel_2 ( + data_record* data_, int + ): slack(25),data(data_) {} + /*! + ensures + - *this is initialized with data_ as its data member + !*/ + + public: + + bigint_kernel_2 ( + ); + + bigint_kernel_2 ( + uint32 value + ); + + bigint_kernel_2 ( + const bigint_kernel_2& item + ); + + virtual ~bigint_kernel_2 ( + ); + + const bigint_kernel_2 operator+ ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator+= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator- ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator-= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator* ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator*= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator/ ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator/= ( + const bigint_kernel_2& rhs + ); + + const bigint_kernel_2 operator% ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator%= ( + const bigint_kernel_2& rhs + ); + + bool operator < ( + const bigint_kernel_2& rhs + ) const; + + bool operator == ( + const bigint_kernel_2& rhs + ) const; + + bigint_kernel_2& operator= ( + const bigint_kernel_2& rhs + ); + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_2& rhs + ); + + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_2& rhs + ); + + bigint_kernel_2& operator++ ( + ); + + const bigint_kernel_2 operator++ ( + int + ); + + bigint_kernel_2& operator-- ( + ); + + const bigint_kernel_2 operator-- ( + int + ); + + friend const bigint_kernel_2 operator+ ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator+ ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator+= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator- ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator- ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator-= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator* ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator* ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator*= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator/ ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator/ ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator/= ( + uint16 rhs + ); + + friend const bigint_kernel_2 operator% ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend const bigint_kernel_2 operator% ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + bigint_kernel_2& operator%= ( + uint16 rhs + ); + + friend bool operator < ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + friend bool operator < ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + friend bool operator == ( + const bigint_kernel_2& lhs, + uint16 rhs + ); + + friend bool operator == ( + uint16 lhs, + const bigint_kernel_2& rhs + ); + + bigint_kernel_2& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_2& item + ) { data_record* temp = data; data = item.data; item.data = temp; } + + + private: + + typedef double t; + typedef std::complex ct; + + void fft( + ct* data, + unsigned long len + ) const; + /*! + requires + - len == x^n for some integer n (i.e. len is a power of 2) + - len > 0 + ensures + - #data == the FT decimation in frequency of data + !*/ + + void ifft( + ct* data, + unsigned long len + ) const; + /*! + requires + - len == x^n for some integer n (i.e. len is a power of 2) + - len > 0 + ensures + - #data == the inverse decimation in frequency of data. + (i.e. the inverse of what fft(data,len,-1) does to data) + !*/ + + void long_add ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= max(lhs->digits_used,rhs->digits_used) + 1 + ensures + - result == lhs + rhs + !*/ + + void long_sub ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - lhs >= rhs + - result->size >= lhs->digits_used + ensures + - result == lhs - rhs + !*/ + + void long_div ( + const data_record* lhs, + const data_record* rhs, + data_record* result, + data_record* remainder + ) const; + /*! + requires + - rhs != 0 + - result->size >= lhs->digits_used + - remainder->size >= lhs->digits_used + - each parameter is unique (i.e. lhs != result, lhs != remainder, etc.) + ensures + - result == lhs / rhs + - remainder == lhs % rhs + !*/ + + void long_mul ( + const data_record* lhs, + const data_record* rhs, + data_record* result + ) const; + /*! + requires + - result->size >= lhs->digits_used + rhs->digits_used + - result != lhs + - result != rhs + ensures + - result == lhs * rhs + !*/ + + void short_add ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->size + 1 + ensures + - result == data + value + !*/ + + void short_sub ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - data >= value + - result->size >= data->digits_used + ensures + - result == data - value + !*/ + + void short_mul ( + const data_record* data, + uint16 value, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + 1 + ensures + - result == data * value + !*/ + + void short_div ( + const data_record* data, + uint16 value, + data_record* result, + uint16& remainder + ) const; + /*! + requires + - value != 0 + - result->size >= data->digits_used + ensures + - result = data*value + - remainder = data%value + !*/ + + void shift_left ( + const data_record* data, + data_record* result, + uint32 shift_amount + ) const; + /*! + requires + - result->size >= data->digits_used + shift_amount/8 + 1 + ensures + - result == data << shift_amount + !*/ + + void shift_right ( + const data_record* data, + data_record* result + ) const; + /*! + requires + - result->size >= data->digits_used + ensures + - result == data >> 1 + !*/ + + bool is_less_than ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs < rhs + - returns false otherwise + !*/ + + bool is_equal_to ( + const data_record* lhs, + const data_record* rhs + ) const; + /*! + ensures + - returns true if lhs == rhs + - returns false otherwise + !*/ + + void increment ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + - dest->size >= source->digits_used + 1 + ensures + - dest = source + 1 + !*/ + + void decrement ( + const data_record* source, + data_record* dest + ) const; + /*! + requires + source != 0 + ensuers + dest = source - 1 + !*/ + + // member data + const uint32 slack; + data_record* data; + + + + }; + + inline void swap ( + bigint_kernel_2& a, + bigint_kernel_2& b + ) { a.swap(b); } + + inline void serialize ( + const bigint_kernel_2& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + inline void deserialize ( + bigint_kernel_2& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in >> item; + in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + + inline bool operator> (const bigint_kernel_2& a, const bigint_kernel_2& b) { return b < a; } + inline bool operator!= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a == b); } + inline bool operator<= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(b < a); } + inline bool operator>= (const bigint_kernel_2& a, const bigint_kernel_2& b) { return !(a < b); } + +} + +#ifdef NO_MAKEFILE +#include "bigint_kernel_2.cpp" +#endif + +#endif // DLIB_BIGINT_KERNEl_2_ + diff --git a/dlib/bigint/bigint_kernel_abstract.h b/dlib/bigint/bigint_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..644129c4be99b9bb82164aff1cf6870dc8674080 --- /dev/null +++ b/dlib/bigint/bigint_kernel_abstract.h @@ -0,0 +1,670 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIGINT_KERNEl_ABSTRACT_ +#ifdef DLIB_BIGINT_KERNEl_ABSTRACT_ + +#include +#include "../algs.h" +#include "../serialize.h" +#include "../uintn.h" + +namespace dlib +{ + + class bigint + { + /*! + INITIAL VALUE + *this == 0 + + WHAT THIS OBJECT REPRESENTS + This object represents an arbitrary precision unsigned integer + + the following operators are supported: + operator + + operator += + operator - + operator -= + operator * + operator *= + operator / + operator /= + operator % + operator %= + operator == + operator < + operator = + operator << (for writing to ostreams) + operator >> (for reading from istreams) + operator++ // pre increment + operator++(int) // post increment + operator-- // pre decrement + operator--(int) // post decrement + + + the other comparison operators(>, !=, <=, and >=) are + available and come from the templates in dlib::relational_operators + + THREAD SAFETY + bigint may be reference counted so it is very unthread safe. + use with care in a multithreaded program + + !*/ + + public: + + bigint ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + bigint ( + uint32 value + ); + /*! + requires + - value <= (2^32)-1 + ensures + - #*this is properly initialized + - #*this == value + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + bigint ( + const bigint& item + ); + /*! + ensures + - #*this is properly initialized + - #*this == value + throws + - std::bad_alloc + if this is thrown the bigint will be unusable but + will not leak memory + !*/ + + virtual ~bigint ( + ); + /*! + ensures + - all resources associated with #*this have been released + !*/ + + const bigint operator+ ( + const bigint& rhs + ) const; + /*! + ensures + - returns the result of adding rhs to *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator+= ( + const bigint& rhs + ); + /*! + ensures + - #*this == *this + rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator- ( + const bigint& rhs + ) const; + /*! + requires + - *this >= rhs + ensures + - returns the result of subtracting rhs from *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-= ( + const bigint& rhs + ); + /*! + requires + - *this >= rhs + ensures + - #*this == *this - rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator* ( + const bigint& rhs + ) const; + /*! + ensures + - returns the result of multiplying *this and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator*= ( + const bigint& rhs + ); + /*! + ensures + - #*this == *this * rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator/ ( + const bigint& rhs + ) const; + /*! + requires + - rhs != 0 + ensures + - returns the result of dividing *this by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator/= ( + const bigint& rhs + ); + /*! + requires + - rhs != 0 + ensures + - #*this == *this / rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator% ( + const bigint& rhs + ) const; + /*! + requires + - rhs != 0 + ensures + - returns the result of *this mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator%= ( + const bigint& rhs + ); + /*! + requires + - rhs != 0 + ensures + - #*this == *this % rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bool operator < ( + const bigint& rhs + ) const; + /*! + ensures + - returns true if *this is less than rhs + - returns false otherwise + !*/ + + bool operator == ( + const bigint& rhs + ) const; + /*! + ensures + - returns true if *this and rhs represent the same number + - returns false otherwise + !*/ + + bigint& operator= ( + const bigint& rhs + ); + /*! + ensures + - #*this == rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + + friend std::ostream& operator<< ( + std::ostream& out, + const bigint& rhs + ); + /*! + ensures + - the number in *this has been written to #out as a base ten number + throws + - std::bad_alloc + if this function throws then it has no effect (nothing + is written to out) + !*/ + + friend std::istream& operator>> ( + std::istream& in, + bigint& rhs + ); + /*! + ensures + - reads a number from in and puts it into #*this + - if (there is no positive base ten number on the input stream ) then + - #in.fail() == true + throws + - std::bad_alloc + if this function throws the value in rhs is undefined and some + characters may have been read from in. rhs is still usable though, + its value is just unknown. + !*/ + + + bigint& operator++ ( + ); + /*! + ensures + - #*this == *this + 1 + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator++ ( + int + ); + /*! + ensures + - #*this == *this + 1 + - returns *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-- ( + ); + /*! + requires + - *this != 0 + ensures + - #*this == *this - 1 + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + const bigint operator-- ( + int + ); + /*! + requires + - *this != 0 + ensures + - #*this == *this - 1 + - returns *this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + void swap ( + bigint& item + ); + /*! + ensures + - swaps *this and item + !*/ + + + // ------------------------------------------------------------------ + // ---- The following functions are identical to the above ----- + // ---- but take uint16 as one of their arguments. They --- + // ---- exist only to allow for a more efficient implementation --- + // ------------------------------------------------------------------ + + + friend const bigint operator+ ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns the result of adding rhs to lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator+ ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns the result of adding rhs to lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator+= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == *this + rhs + - returns #this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator- ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs >= rhs + - lhs <= 65535 + ensures + - returns the result of subtracting rhs from lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator- ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - lhs >= rhs + - rhs <= 65535 + ensures + - returns the result of subtracting rhs from lhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator-= ( + uint16 rhs + ); + /*! + requires + - *this >= rhs + - rhs <= 65535 + ensures + - #*this == *this - rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator* ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns the result of multiplying lhs and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator* ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns the result of multiplying lhs and rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator*= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == *this * rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator/ ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - rhs != 0 + - lhs <= 65535 + ensures + - returns the result of dividing lhs by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator/ ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - returns the result of dividing lhs by rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator/= ( + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - #*this == *this / rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator% ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - rhs != 0 + - lhs <= 65535 + ensures + - returns the result of lhs mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + friend const bigint operator% ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - returns the result of lhs mod rhs + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + bigint& operator%= ( + uint16 rhs + ); + /*! + requires + - rhs != 0 + - rhs <= 65535 + ensures + - #*this == *this % rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + + friend bool operator < ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns true if lhs is less than rhs + - returns false otherwise + !*/ + + friend bool operator < ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns true if lhs is less than rhs + - returns false otherwise + !*/ + + friend bool operator == ( + const bigint& lhs, + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - returns true if lhs and rhs represent the same number + - returns false otherwise + !*/ + + friend bool operator == ( + uint16 lhs, + const bigint& rhs + ); + /*! + requires + - lhs <= 65535 + ensures + - returns true if lhs and rhs represent the same number + - returns false otherwise + !*/ + + bigint& operator= ( + uint16 rhs + ); + /*! + requires + - rhs <= 65535 + ensures + - #*this == rhs + - returns #*this + throws + - std::bad_alloc + if this function throws then it has no effect + !*/ + + }; + + inline void swap ( + bigint& a, + bigint& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + void serialize ( + const bigint& item, + std::istream& in + ); + /*! + provides serialization support + !*/ + + void deserialize ( + bigint& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ + + inline bool operator> (const bigint& a, const bigint& b) { return b < a; } + inline bool operator!= (const bigint& a, const bigint& b) { return !(a == b); } + inline bool operator<= (const bigint& a, const bigint& b) { return !(b < a); } + inline bool operator>= (const bigint& a, const bigint& b) { return !(a < b); } +} + +#endif // DLIB_BIGINT_KERNEl_ABSTRACT_ + diff --git a/dlib/bigint/bigint_kernel_c.h b/dlib/bigint/bigint_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..47089730af5472af32ddc40f43a8ff9df3c8033d --- /dev/null +++ b/dlib/bigint/bigint_kernel_c.h @@ -0,0 +1,1140 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIGINT_KERNEl_C_ +#define DLIB_BIGINT_KERNEl_C_ + +#include "bigint_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + class bigint_kernel_c + { + bigint_base data; + + explicit bigint_kernel_c ( + const bigint_base& item + ) : data(item) {} + + public: + + + bigint_kernel_c ( + ); + + bigint_kernel_c ( + uint32 value + ); + + bigint_kernel_c ( + const bigint_kernel_c& item + ); + + ~bigint_kernel_c ( + ); + + const bigint_kernel_c operator+ ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator+= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator- ( + const bigint_kernel_c& rhs + ) const; + bigint_kernel_c& operator-= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator* ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator*= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator/ ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator/= ( + const bigint_kernel_c& rhs + ); + + const bigint_kernel_c operator% ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator%= ( + const bigint_kernel_c& rhs + ); + + bool operator < ( + const bigint_kernel_c& rhs + ) const; + + bool operator == ( + const bigint_kernel_c& rhs + ) const; + + bigint_kernel_c& operator= ( + const bigint_kernel_c& rhs + ); + + template + friend std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_c& rhs + ); + + template + friend std::istream& operator>> ( + std::istream& in, + bigint_kernel_c& rhs + ); + + bigint_kernel_c& operator++ ( + ); + + const bigint_kernel_c operator++ ( + int + ); + + bigint_kernel_c& operator-- ( + ); + + const bigint_kernel_c operator-- ( + int + ); + + template + friend const bigint_kernel_c operator+ ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator+ ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator+= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator- ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator- ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator-= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator* ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator* ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator*= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator/ ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator/ ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator/= ( + uint16 rhs + ); + + template + friend const bigint_kernel_c operator% ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend const bigint_kernel_c operator% ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + bigint_kernel_c& operator%= ( + uint16 rhs + ); + + template + friend bool operator < ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + template + friend bool operator < ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + template + friend bool operator == ( + const bigint_kernel_c& lhs, + uint16 rhs + ); + + template + friend bool operator == ( + uint16 lhs, + const bigint_kernel_c& rhs + ); + + bigint_kernel_c& operator= ( + uint16 rhs + ); + + + void swap ( + bigint_kernel_c& item + ) { data.swap(item.data); } + + }; + + template < + typename bigint_base + > + void swap ( + bigint_kernel_c& a, + bigint_kernel_c& b + ) { a.swap(b); } + + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + inline void serialize ( + const bigint_kernel_c& item, + std::ostream& out + ) + { + std::ios::fmtflags oldflags = out.flags(); + out << item << ' '; + out.flags(oldflags); + if (!out) throw serialization_error("Error serializing object of type bigint_kernel_c"); + } + + template < + typename bigint_base + > + inline void deserialize ( + bigint_kernel_c& item, + std::istream& in + ) + { + std::ios::fmtflags oldflags = in.flags(); + in >> item; + in.flags(oldflags); + if (in.get() != ' ') + { + item = 0; + throw serialization_error("Error deserializing object of type bigint_kernel_c"); + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + ) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + uint32 value + ) : + data(value) + { + // make sure requires clause is not broken + DLIB_CASSERT( value <= 0xFFFFFFFF , + "\tbigint::bigint(uint16)" + << "\n\t value must be <= (2^32)-1" + << "\n\tthis: " << this + << "\n\tvalue: " << value + ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + bigint_kernel_c ( + const bigint_kernel_c& item + ) : + data(item.data) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c:: + ~bigint_kernel_c ( + ) + {} + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator+ ( + const bigint_kernel_c& rhs + ) const + { + return bigint_kernel_c(data + rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator+= ( + const bigint_kernel_c& rhs + ) + { + data += rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator- ( + const bigint_kernel_c& rhs + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < rhs), + "\tconst bigint bigint::operator-(const bigint&)" + << "\n\t *this should not be less than rhs" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(data-rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < rhs), + "\tbigint& bigint::operator-=(const bigint&)" + << "\n\t *this should not be less than rhs" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + data -= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator* ( + const bigint_kernel_c& rhs + ) const + { + return bigint_kernel_c(data * rhs.data ); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator*= ( + const bigint_kernel_c& rhs + ) + { + data *= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator/ ( + const bigint_kernel_c& rhs + ) const + { + //make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tconst bigint bigint::operator/(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data/rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator/= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tbigint& bigint::operator/=(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + data /= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator% ( + const bigint_kernel_c& rhs + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tconst bigint bigint::operator%(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data%rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator%= ( + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0), + "\tbigint& bigint::operator%=(const bigint&)" + << "\n\t can't divide by zero" + << "\n\tthis: " << this + ); + + // call the real function + data %= rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool bigint_kernel_c:: + operator < ( + const bigint_kernel_c& rhs + ) const + { + return data < rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool bigint_kernel_c:: + operator == ( + const bigint_kernel_c& rhs + ) const + { + return data == rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator= ( + const bigint_kernel_c& rhs + ) + { + data = rhs.data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + std::ostream& operator<< ( + std::ostream& out, + const bigint_kernel_c& rhs + ) + { + out << rhs.data; + return out; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + std::istream& operator>> ( + std::istream& in, + bigint_kernel_c& rhs + ) + { + in >> rhs.data; + return in; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator++ ( + ) + { + ++data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator++ ( + int + ) + { + return bigint_kernel_c(data++); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-- ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this == 0), + "\tbigint& bigint::operator--()" + << "\n\t *this to subtract from *this it must not be zero to begin with" + << "\n\tthis: " << this + ); + + // call the real function + --data; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c bigint_kernel_c:: + operator-- ( + int + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(*this == 0), + "\tconst bigint bigint::operator--(int)" + << "\n\t *this to subtract from *this it must not be zero to begin with" + << "\n\tthis: " << this + ); + + // call the real function + return bigint_kernel_c(data--); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator+ ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tconst bigint operator+(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(static_cast(lhs)+rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator+ ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tconst bigint operator+(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs.data+static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator+= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbigint& bigint::operator+=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + data += rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator- ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( !(static_cast(lhs) < rhs) && lhs <= 65535, + "\tconst bigint operator-(uint16,const bigint&)" + << "\n\t lhs must be greater than or equal to rhs and lhs <= 65535" + << "\n\tlhs: " << lhs + << "\n\trhs: " << rhs + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // call the real function + return bigint_kernel_c(static_cast(lhs)-rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator- ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(lhs < static_cast(rhs)) && rhs <= 65535, + "\tconst bigint operator-(const bigint&,uint16)" + << "\n\t lhs must be greater than or equal to rhs and rhs <= 65535" + << "\n\tlhs: " << lhs + << "\n\trhs: " << rhs + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data-static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator-= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(*this < static_cast(rhs)) && rhs <= 65535, + "\tbigint& bigint::operator-=(uint16)" + << "\n\t *this must not be less than rhs and rhs <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + // call the real function + data -= static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator* ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tconst bigint operator*(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs*rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator* ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tconst bigint operator*(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return bigint_kernel_c(lhs.data*rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator*= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\t bigint bigint::operator*=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\t*this: " << *this + << "\n\trhs: " << rhs + ); + + data *= static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator/ ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && lhs <= 65535, + "\tconst bigint operator/(uint16,const bigint&)" + << "\n\t you can't divide by zero and lhs <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\tlhs: " << lhs + ); + + // call the real function + return bigint_kernel_c(lhs/rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator/ ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tconst bigint operator/(const bigint&,uint16)" + << "\n\t you can't divide by zero and rhs <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data/static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator/= ( + uint16 rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && static_cast(rhs) <= 65535, + "\tbigint& bigint::operator/=(uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\trhs: " << rhs + ); + + // call the real function + data /= rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator% ( + uint16 lhs, + const bigint_kernel_c& rhs + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && static_cast(lhs) <= 65535, + "\tconst bigint operator%(uint16,const bigint&)" + << "\n\t you can't divide by zero and lhs must be <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\tlhs: " << lhs + ); + + // call the real function + return bigint_kernel_c(lhs%rhs.data); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + const bigint_kernel_c operator% ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tconst bigint operator%(const bigint&,uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\t&lhs: " << &lhs + << "\n\t&rhs: " << &rhs + << "\n\trhs: " << rhs + ); + + // call the real function + return bigint_kernel_c(lhs.data%static_cast(rhs)); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator%= ( + uint16 r + ) + { + + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( !(rhs == 0) && rhs <= 65535, + "\tbigint& bigint::operator%=(uint16)" + << "\n\t you can't divide by zero and rhs must be <= 65535" + << "\n\tthis: " << this + << "\n\trhs: " << rhs + ); + + // call the real function + data %= rhs; + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator < ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tbool operator<(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return static_cast(lhs) < rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator < ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbool operator<(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return lhs.data < static_cast(rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator == ( + const bigint_kernel_c& lhs, + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbool operator==(const bigint&, uint16)" + << "\n\t rhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return lhs.data == static_cast(rhs); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bool operator == ( + uint16 l, + const bigint_kernel_c& rhs + ) + { + uint32 lhs = l; + // make sure requires clause is not broken + DLIB_CASSERT( lhs <= 65535, + "\tbool operator==(uint16, const bigint&)" + << "\n\t lhs must be <= 65535" + << "\n\trhs: " << rhs + << "\n\tlhs: " << lhs + ); + + return static_cast(lhs) == rhs.data; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bigint_base + > + bigint_kernel_c& bigint_kernel_c:: + operator= ( + uint16 r + ) + { + uint32 rhs = r; + // make sure requires clause is not broken + DLIB_CASSERT( rhs <= 65535, + "\tbigint bigint::operator=(uint16)" + << "\n\t rhs must be <= 65535" + << "\n\t*this: " << *this + << "\n\tthis: " << this + << "\n\tlhs: " << rhs + ); + + data = static_cast(rhs); + return *this; + } + +// ---------------------------------------------------------------------------------------- + + template < typename bigint_base > + inline bool operator> (const bigint_kernel_c& a, const bigint_kernel_c& b) { return b < a; } + template < typename bigint_base > + inline bool operator!= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a == b); } + template < typename bigint_base > + inline bool operator<= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(b < a); } + template < typename bigint_base > + inline bool operator>= (const bigint_kernel_c& a, const bigint_kernel_c& b) { return !(a < b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIGINT_KERNEl_C_ + diff --git a/dlib/binary_search_tree.h b/dlib/binary_search_tree.h new file mode 100644 index 0000000000000000000000000000000000000000..5273e8ce9230544f31d29c913f4b75518474a525 --- /dev/null +++ b/dlib/binary_search_tree.h @@ -0,0 +1,50 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREe_ +#define DLIB_BINARY_SEARCH_TREe_ + + +#include "binary_search_tree/binary_search_tree_kernel_1.h" +#include "binary_search_tree/binary_search_tree_kernel_2.h" +#include "binary_search_tree/binary_search_tree_kernel_c.h" + + +#include "algs.h" +#include + + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class binary_search_tree + { + binary_search_tree() {} + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef binary_search_tree_kernel_1 + kernel_1a; + typedef binary_search_tree_kernel_c + kernel_1a_c; + + + // kernel_2a + typedef binary_search_tree_kernel_2 + kernel_2a; + typedef binary_search_tree_kernel_c + kernel_2a_c; + + }; +} + +#endif // DLIB_BINARY_SEARCH_TREe_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_1.h b/dlib/binary_search_tree/binary_search_tree_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..f0a0cfbd820b93431dd092d5d4667ffe306509ed --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_1.h @@ -0,0 +1,2064 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_1_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_1_ + +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare = std::less + > + class binary_search_tree_kernel_1 : public enumerable >, + public asc_pair_remover + { + + /*! + INITIAL VALUE + tree_size == 0 + tree_root == 0 + tree_height == 0 + at_start_ == true + current_element == 0 + stack == array of 50 node pointers + stack_pos == 0 + + + CONVENTION + tree_size == size() + tree_height == height() + + stack[stack_pos-1] == pop() + + current_element_valid() == (current_element != 0) + if (current_element_valid()) then + element() == current_element->d and current_element->r + at_start_ == at_start() + if (current_element != 0 && current_element != tree_root) then + stack[stack_pos-1] == the parent of the node pointed to by current_element + + if (tree_size != 0) + tree_root == pointer to the root node of the binary search tree + else + tree_root == 0 + + + for all nodes: + { + left points to the left subtree or 0 if there is no left subtree and + right points to the right subtree or 0 if there is no right subtree and + all elements in a left subtree are <= the root and + all elements in a right subtree are >= the root and + d is the item in the domain of *this contained in the node + r is the item in the range of *this contained in the node + balance: + balance == 0 if both subtrees have the same height + balance == -1 if the left subtree has a height that is greater + than the height of the right subtree by 1 + balance == 1 if the right subtree has a height that is greater + than the height of the left subtree by 1 + for all trees: + the height of the left and right subtrees differ by at most one + } + + !*/ + + class node + { + public: + node* left; + node* right; + domain d; + range r; + signed char balance; + }; + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree_kernel_1( + ) : + tree_size(0), + tree_root(0), + current_element(0), + tree_height(0), + at_start_(true), + stack_pos(0), + stack(ppool.allocate_array(50)) + { + } + + virtual ~binary_search_tree_kernel_1( + ); + + inline void clear( + ); + + inline short height ( + ) const; + + inline unsigned long count ( + const domain& item + ) const; + + inline void add ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& item + ); + + inline const range* operator[] ( + const domain& item + ) const; + + inline range* operator[] ( + const domain& item + ); + + inline void swap ( + binary_search_tree_kernel_1& item + ); + + // function from the asc_pair_remover interface + void remove_any ( + domain& d, + range& r + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + bool move_next ( + ) const; + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + void position_enumerator ( + const domain& d + ) const; + + private: + + + inline void rotate_left ( + node*& t + ); + /*! + requires + - t->balance == 2 + - t->right->balance == 0 or 1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + !*/ + + inline void rotate_right ( + node*& t + ); + /*! + requires + - t->balance == -2 + - t->left->balance == 0 or -1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t->balance is between 1 and -1 + - #t now has a height smaller by 1 if #t->balance == 0 + + !*/ + + inline void double_rotate_right ( + node*& t + ); + /*! + requires + - t->balance == -2 + - t->left->balance == 1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t now has a balance of 0 + - #t now has a height smaller by 1 + !*/ + + inline void double_rotate_left ( + node*& t + ); + /*! + requires + - t->balance == 2 + - t->right->balance == -1 + - t == reference to the pointer in t's parent node that points to t + ensures + - #t is still a binary search tree + - #t now has a balance of 0 + - #t now has a height smaller by 1 + !*/ + + bool remove_biggest_element_in_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t != 0 (i.e. there must be something in the tree to remove) + - t == reference to the pointer in t's parent node that points to t + ensures + - the biggest node in t has been removed + - the biggest node domain element in t has been put into #d + - the biggest node range element in t has been put into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool remove_least_element_in_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t != 0 (i.e. there must be something in the tree to remove) + - t == reference to the pointer in t's parent node that points to t + ensures + - the least node in t has been removed + - the least node domain element in t has been put into #d + - the least node range element in t has been put into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool add_to_tree ( + node*& t, + domain& d, + range& r + ); + /*! + requires + - t == reference to the pointer in t's parent node that points to t + ensures + - the mapping (d --> r) has been added to #t + - #d and #r have initial values for their types + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has grown by one + !*/ + + bool remove_from_tree ( + node*& t, + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - return_reference(t,d) != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - #d_copy is equivalent to d + - an element in t equivalent to d has been removed and swapped + into #d_copy and its associated range object has been + swapped into #r + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + bool remove_from_tree ( + node*& t, + const domain& item + ); + /*! + requires + - return_reference(t,item) != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - an element in t equivalent to item has been removed + - #t is still a binary search tree + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + const range* return_reference ( + const node* t, + const domain& d + ) const; + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + range* return_reference ( + node* t, + const domain& d + ); + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + + inline bool keep_node_balanced ( + node*& t + ); + /*! + requires + - t != 0 + - t == reference to the pointer in t's parent node that points to t + ensures + - if (t->balance is < 1 or > 1) then + - keep_node_balanced() will ensure that #t->balance == 0, -1, or 1 + - #t is still a binary search tree + - returns true if it made the tree one height shorter + - returns false if it didn't change the height + !*/ + + + unsigned long get_count ( + const domain& item, + node* tree_root + ) const; + /*! + requires + - tree_root == the root of a binary search tree or 0 + ensures + - if (tree_root == 0) then + - returns 0 + - else + - returns the number of elements in tree_root that are + equivalent to item + !*/ + + + void delete_tree ( + node* t + ); + /*! + requires + - t != 0 + ensures + - deallocates the node pointed to by t and all of t's left and right children + !*/ + + + void push ( + node* n + ) const { stack[stack_pos] = n; ++stack_pos; } + /*! + ensures + - pushes n onto the stack + !*/ + + + node* pop ( + ) const { --stack_pos; return stack[stack_pos]; } + /*! + ensures + - pops the top of the stack and returns it + !*/ + + + + bool fix_stack ( + node* t, + unsigned char depth = 0 + ); + /*! + requires + - current_element != 0 + - depth == 0 + - t == tree_root + ensures + - makes the stack contain the correct set of parent pointers. + also adjusts stack_pos so it is correct. + - #t is still a binary search tree + !*/ + + bool remove_current_element_from_tree ( + node*& t, + domain& d, + range& r, + unsigned long cur_stack_pos = 1 + ); + /*! + requires + - t == tree_root + - cur_stack_pos == 1 + - current_element != 0 + ensures + - removes the data in the node given by current_element and swaps it into + #d and #r. + - #t is still a binary search tree + - the enumerator is advances on to the next element but its stack is + potentially corrupted. so you must call fix_stack(tree_root) to fix + it. + - returns false if the height of the tree has not changed + - returns true if the height of the tree has shrunk by one + !*/ + + + // data members + + mutable mpair p; + unsigned long tree_size; + node* tree_root; + mutable node* current_element; + typename mem_manager::template rebind::other pool; + typename mem_manager::template rebind::other ppool; + short tree_height; + mutable bool at_start_; + mutable unsigned char stack_pos; + mutable node** stack; + compare comp; + + // restricted functions + binary_search_tree_kernel_1(binary_search_tree_kernel_1&); + binary_search_tree_kernel_1& operator=(binary_search_tree_kernel_1&); + + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree_kernel_1& a, + binary_search_tree_kernel_1& b + ) { a.swap(b); } + + + + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree_kernel_1& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_1"); + } + } + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + binary_search_tree_kernel_1:: + ~binary_search_tree_kernel_1 ( + ) + { + ppool.deallocate_array(stack); + if (tree_size != 0) + { + delete_tree(tree_root); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + clear ( + ) + { + if (tree_size > 0) + { + delete_tree(tree_root); + tree_root = 0; + tree_size = 0; + tree_height = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + size_t binary_search_tree_kernel_1:: + size ( + ) const + { + return tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_1:: + height ( + ) const + { + return tree_height; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_1:: + count ( + const domain& item + ) const + { + return get_count(item,tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + add ( + domain& d, + range& r + ) + { + tree_height += add_to_tree(tree_root,d,r); + ++tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + tree_height -= remove_from_tree(tree_root,d,d_copy,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + destroy ( + const domain& item + ) + { + tree_height -= remove_from_tree(tree_root,item); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_any ( + domain& d, + range& r + ) + { + tree_height -= remove_least_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_1:: + operator[] ( + const domain& item + ) + { + return return_reference(tree_root,item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_1:: + operator[] ( + const domain& item + ) const + { + return return_reference(tree_root,item); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + swap ( + binary_search_tree_kernel_1& item + ) + { + pool.swap(item.pool); + ppool.swap(item.ppool); + exchange(p,item.p); + exchange(stack,item.stack); + exchange(stack_pos,item.stack_pos); + exchange(comp,item.comp); + + + node* tree_root_temp = item.tree_root; + unsigned long tree_size_temp = item.tree_size; + short tree_height_temp = item.tree_height; + node* current_element_temp = item.current_element; + bool at_start_temp = item.at_start_; + + item.tree_root = tree_root; + item.tree_size = tree_size; + item.tree_height = tree_height; + item.current_element = current_element; + item.at_start_ = at_start_; + + tree_root = tree_root_temp; + tree_size = tree_size_temp; + tree_height = tree_height_temp; + current_element = current_element_temp; + at_start_ = at_start_temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_last_in_order ( + domain& d, + range& r + ) + { + tree_height -= remove_biggest_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + remove_current_element ( + domain& d, + range& r + ) + { + tree_height -= remove_current_element_from_tree(tree_root,d,r); + --tree_size; + + // fix the enumerator stack if we need to + if (current_element) + fix_stack(tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + position_enumerator ( + const domain& d + ) const + { + // clear the enumerator state and make sure the stack is empty + reset(); + at_start_ = false; + node* t = tree_root; + bool went_left = false; + while (t != 0) + { + if ( comp(d , t->d) ) + { + push(t); + // if item is on the left then look in left + t = t->left; + went_left = true; + } + else if (comp(t->d , d)) + { + push(t); + // if item is on the right then look in right + t = t->right; + went_left = false; + } + else + { + current_element = t; + return; + } + } + + // if we didn't find any matches but there might be something after the + // d in this tree. + if (stack_pos > 0) + { + current_element = pop(); + // if we went left from this node then this node is the next + // biggest. + if (went_left) + { + return; + } + else + { + move_next(); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + stack_pos = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const map_pair& binary_search_tree_kernel_1:: + element ( + ) const + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + map_pair& binary_search_tree_kernel_1:: + element ( + ) + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + move_next ( + ) const + { + // if we haven't started iterating yet + if (at_start_) + { + at_start_ = false; + if (tree_size == 0) + { + return false; + } + else + { + // find the first element in the tree + current_element = tree_root; + node* temp = current_element->left; + while (temp != 0) + { + push(current_element); + current_element = temp; + temp = current_element->left; + } + return true; + } + } + else + { + if (current_element == 0) + { + return false; + } + else + { + node* temp; + bool went_up; // true if we went up the tree from a child node to parent + bool from_left = false; // true if we went up and were coming from a left child node + // find the next element in the tree + if (current_element->right != 0) + { + // go right and down + temp = current_element; + push(current_element); + current_element = temp->right; + went_up = false; + } + else + { + // go up to the parent if we can + if (current_element == tree_root) + { + // in this case we have iterated over all the element of the tree + current_element = 0; + return false; + } + went_up = true; + node* parent = pop(); + + + from_left = (parent->left == current_element); + // go up to parent + current_element = parent; + } + + + while (true) + { + if (went_up) + { + if (from_left) + { + // in this case we have found the next node + break; + } + else + { + if (current_element == tree_root) + { + // in this case we have iterated over all the elements + // in the tree + current_element = 0; + return false; + } + // we should go up + node* parent = pop(); + from_left = (parent->left == current_element); + current_element = parent; + } + } + else + { + // we just went down to a child node + if (current_element->left != 0) + { + // go left + went_up = false; + temp = current_element; + push(current_element); + current_element = temp->left; + } + else + { + // if there is no left child then we have found the next node + break; + } + } + } + + return true; + } + } + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + delete_tree ( + node* t + ) + { + if (t->left != 0) + delete_tree(t->left); + if (t->right != 0) + delete_tree(t->right); + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + rotate_left ( + node*& t + ) + { + + // set the new balance numbers + if (t->right->balance == 1) + { + t->balance = 0; + t->right->balance = 0; + } + else + { + t->balance = 1; + t->right->balance = -1; + } + + // perform the rotation + node* temp = t->right; + t->right = temp->left; + temp->left = t; + t = temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + rotate_right ( + node*& t + ) + { + // set the new balance numbers + if (t->left->balance == -1) + { + t->balance = 0; + t->left->balance = 0; + } + else + { + t->balance = -1; + t->left->balance = 1; + } + + // preform the rotation + node* temp = t->left; + t->left = temp->right; + temp->right = t; + t = temp; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + double_rotate_right ( + node*& t + ) + { + + node* temp = t; + t = t->left->right; + + temp->left->right = t->left; + t->left = temp->left; + + temp->left = t->right; + t->right = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + t->balance = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_1:: + double_rotate_left ( + node*& t + ) + { + node* temp = t; + t = t->right->left; + + temp->right->left = t->right; + t->right = temp->right; + + temp->right = t->left; + t->left = temp; + + if (t->balance < 0) + { + t->left->balance = 0; + t->right->balance = 1; + } + else if (t->balance > 0) + { + t->left->balance = -1; + t->right->balance = 0; + } + else + { + t->left->balance = 0; + t->right->balance = 0; + } + + t->balance = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_biggest_element_in_tree ( + node*& t, + domain& d, + range& r + ) + { + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + // if the right tree is an empty tree + if ( tree.right == 0) + { + // swap nodes domain and range elements into d and r + exchange(d,tree.d); + exchange(r,tree.r); + + // plug hole left by removing this node + t = tree.left; + + // delete the node that was just removed + pool.deallocate(&tree); + + // return that the height of this part of the tree has decreased + return true; + } + else + { + + // keep going right + + // if remove made the tree one height shorter + if ( remove_biggest_element_in_tree(tree.right,d,r) ) + { + // if this caused the current tree to strink then report that + if ( tree.balance == 1) + { + --tree.balance; + return true; + } + else + { + --tree.balance; + return keep_node_balanced(t); + } + } + + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_least_element_in_tree ( + node*& t, + domain& d, + range& r + ) + { + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + // if the left tree is an empty tree + if ( tree.left == 0) + { + // swap nodes domain and range elements into d and r + exchange(d,tree.d); + exchange(r,tree.r); + + // plug hole left by removing this node + t = tree.right; + + // delete the node that was just removed + pool.deallocate(&tree); + + // return that the height of this part of the tree has decreased + return true; + } + else + { + + // keep going left + + // if remove made the tree one height shorter + if ( remove_least_element_in_tree(tree.left,d,r) ) + { + // if this caused the current tree to strink then report that + if ( tree.balance == -1) + { + ++tree.balance; + return true; + } + else + { + ++tree.balance; + return keep_node_balanced(t); + } + } + + return false; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + add_to_tree ( + node*& t, + domain& d, + range& r + ) + { + + // if found place to add + if (t == 0) + { + // create a node to add new item into + t = pool.allocate(); + + // make a reference to the current node so we don't have to dereference a + // pointer a bunch of times + node& tree = *t; + + + // set left and right pointers to NULL to indicate that there are no + // left or right subtrees + tree.left = 0; + tree.right = 0; + tree.balance = 0; + + // put d and r into t + exchange(tree.d,d); + exchange(tree.r,r); + + // indicate that the height of this tree has increased + return true; + } + else // keep looking for a place to add the new item + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + signed char old_balance = tree.balance; + + // add the new item to whatever subtree it should go into + if (comp( d , tree.d) ) + tree.balance -= add_to_tree(tree.left,d,r); + else + tree.balance += add_to_tree(tree.right,d,r); + + + // if the tree was balanced to start with + if (old_balance == 0) + { + // if its not balanced anymore then it grew in height + if (tree.balance != 0) + return true; + else + return false; + } + else + { + // if the tree is now balanced then it didn't grow + if (tree.balance == 0) + { + return false; + } + else + { + // if the tree needs to be balanced + if (tree.balance != old_balance) + { + return !keep_node_balanced(t); + } + // if there has been no change in the heights + else + { + return false; + } + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + fix_stack ( + node* t, + unsigned char depth + ) + { + // if we found the node we were looking for + if (t == current_element) + { + stack_pos = depth; + return true; + } + else if (t == 0) + { + return false; + } + + if (!( comp(t->d , current_element->d))) + { + // go left + if (fix_stack(t->left,depth+1)) + { + stack[depth] = t; + return true; + } + } + if (!(comp(current_element->d , t->d))) + { + // go right + if (fix_stack(t->right,depth+1)) + { + stack[depth] = t; + return true; + } + } + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_current_element_from_tree ( + node*& t, + domain& d, + range& r, + unsigned long cur_stack_pos + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if we found the node we were looking for + if (t == current_element) + { + + // swap nodes domain and range elements into d_copy and r + exchange(d,tree.d); + exchange(r,tree.r); + + // if there is no left node + if (tree.left == 0) + { + // move the enumerator on to the next element before we mess with the + // tree + move_next(); + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + // move the enumerator on to the next element before we mess with the + // tree + move_next(); + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // in this case the next current element is going to get swapped back + // into this t node. + current_element = t; + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + + } + else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.left) || + tree.left == current_element ) + { + // go left + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_current_element_from_tree(tree.left,d,r,cur_stack_pos+1); + tree.balance = balance; + return keep_node_balanced(t); + } + } + else if ( (cur_stack_pos < stack_pos && stack[cur_stack_pos] == tree.right) || + tree.right == current_element ) + { + // go right + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_current_element_from_tree(tree.right,d,r,cur_stack_pos+1); + tree.balance = balance; + return keep_node_balanced(t); + } + } + + // this return should never happen but do it anyway to suppress compiler warnings + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_from_tree ( + node*& t, + const domain& d, + domain& d_copy, + range& r + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if item is on the left + if (comp(d , tree.d)) + { + // if the left side of the tree has the greatest height + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d,d_copy,r); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d,d_copy,r); + tree.balance = balance; + return keep_node_balanced(t); + } + + } + // if item is on the right + else if (comp(tree.d , d)) + { + + // if the right side of the tree has the greatest height + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d,d_copy,r); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d,d_copy,r); + tree.balance = balance; + return keep_node_balanced(t); + } + } + // if item is found + else + { + + // swap nodes domain and range elements into d_copy and r + exchange(d_copy,tree.d); + exchange(r,tree.r); + + // if there is no left node + if (tree.left == 0) + { + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + remove_from_tree ( + node*& t, + const domain& d + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if item is on the left + if (comp(d , tree.d)) + { + // if the left side of the tree has the greatest height + if (tree.balance == -1) + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance += remove_from_tree(tree.left,d); + tree.balance = balance; + return keep_node_balanced(t); + } + + } + // if item is on the right + else if (comp(tree.d , d)) + { + + // if the right side of the tree has the greatest height + if (tree.balance == 1) + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d); + tree.balance = balance; + return !tree.balance; + } + else + { + int balance = tree.balance; + balance -= remove_from_tree(tree.right,d); + tree.balance = balance; + return keep_node_balanced(t); + } + } + // if item is found + else + { + + // if there is no left node + if (tree.left == 0) + { + + // plug hole left by removing this node and free memory + t = tree.right; // plug hole with right subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height has changed + return true; + } + // if there is no right node + else if (tree.right == 0) + { + + // plug hole left by removing this node and free memory + t = tree.left; // plug hole with left subtree + + // delete old node + pool.deallocate(&tree); + + // indicate that the height of this tree has changed + return true; + } + // if there are both a left and right sub node + else + { + + // get an element that can replace the one being removed and do this + // if it made the right subtree shrink by one + if (remove_least_element_in_tree(tree.right,tree.d,tree.r)) + { + // adjust the tree height + --tree.balance; + + // if the height of the current tree has dropped by one + if (tree.balance == 0) + { + return true; + } + else + { + return keep_node_balanced(t); + } + } + // else this remove did not effect the height of this tree + else + { + return false; + } + + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_1:: + return_reference ( + node* t, + const domain& d + ) + { + while (t != 0) + { + + if ( comp(d , t->d )) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_1:: + return_reference ( + const node* t, + const domain& d + ) const + { + while (t != 0) + { + + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_1:: + keep_node_balanced ( + node*& t + ) + { + // make a reference to the current node so we don't have to dereference + // a pointer a bunch of times + node& tree = *t; + + // if tree does not need to be balanced then return false + if (tree.balance == 0) + return false; + + + // if tree needs to be rotated left + if (tree.balance == 2) + { + if (tree.right->balance >= 0) + rotate_left(t); + else + double_rotate_left(t); + } + // else if the tree needs to be rotated right + else if (tree.balance == -2) + { + if (tree.left->balance <= 0) + rotate_right(t); + else + double_rotate_right(t); + } + + + if (t->balance == 0) + return true; + else + return false; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_1:: + get_count ( + const domain& d, + node* tree_root + ) const + { + if (tree_root != 0) + { + if (comp(d , tree_root->d)) + { + // go left + return get_count(d,tree_root->left); + } + else if (comp(tree_root->d , d)) + { + // go right + return get_count(d,tree_root->right); + } + else + { + // go left and right to look for more matches + return get_count(d,tree_root->left) + + get_count(d,tree_root->right) + + 1; + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_1_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_2.h b/dlib/binary_search_tree/binary_search_tree_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..83866f27188769e9fc0aff1e7a9a91cd906e17e1 --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_2.h @@ -0,0 +1,1897 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_2_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_2_ + +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare = std::less + > + class binary_search_tree_kernel_2 : public enumerable >, + public asc_pair_remover + { + + /*! + INITIAL VALUE + NIL == pointer to a node that represents a leaf + tree_size == 0 + tree_root == NIL + at_start == true + current_element == 0 + + + CONVENTION + current_element_valid() == (current_element != 0) + if (current_element_valid()) then + element() == current_element->d and current_element->r + at_start_ == at_start() + + + tree_size == size() + + NIL == pointer to a node that represents a leaf + + if (tree_size != 0) + tree_root == pointer to the root node of the binary search tree + else + tree_root == NIL + + tree_root->color == black + Every leaf is black and all leafs are the NIL node. + The number of black nodes in any path from the root to a leaf is the + same. + + for all nodes: + { + - left points to the left subtree or NIL if there is no left subtree + - right points to the right subtree or NIL if there is no right + subtree + - parent points to the parent node or NIL if the node is the root + - ordering of nodes is determined by comparing each node's d member + - all elements in a left subtree are <= the node + - all elements in a right subtree are >= the node + - color == red or black + - if (color == red) + - the node's children are black + } + + !*/ + + class node + { + public: + node* left; + node* right; + node* parent; + domain d; + range r; + char color; + }; + + class mpair : public map_pair + { + public: + const domain* d; + range* r; + + const domain& key( + ) const { return *d; } + + const range& value( + ) const { return *r; } + + range& value( + ) { return *r; } + }; + + + const static char red = 0; + const static char black = 1; + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree_kernel_2( + ) : + NIL(pool.allocate()), + tree_size(0), + tree_root(NIL), + current_element(0), + at_start_(true) + { + NIL->color = black; + NIL->left = 0; + NIL->right = 0; + NIL->parent = 0; + } + + virtual ~binary_search_tree_kernel_2( + ); + + inline void clear( + ); + + inline short height ( + ) const; + + inline unsigned long count ( + const domain& d + ) const; + + inline void add ( + domain& d, + range& r + ); + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + void remove_any ( + domain& d, + range& r + ); + + inline const range* operator[] ( + const domain& item + ) const; + + inline range* operator[] ( + const domain& item + ); + + inline void swap ( + binary_search_tree_kernel_2& item + ); + + // functions from the enumerable interface + inline size_t size ( + ) const; + + bool at_start ( + ) const; + + inline void reset ( + ) const; + + bool current_element_valid ( + ) const; + + const map_pair& element ( + ) const; + + map_pair& element ( + ); + + bool move_next ( + ) const; + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + void position_enumerator ( + const domain& d + ) const; + + private: + + inline void rotate_left ( + node* t + ); + /*! + requires + - t != NIL + - t->right != NIL + ensures + - performs a left rotation around t and its right child + !*/ + + inline void rotate_right ( + node* t + ); + /*! + requires + - t != NIL + - t->left != NIL + ensures + - performs a right rotation around t and its left child + !*/ + + inline void double_rotate_right ( + node* t + ); + /*! + requires + - t != NIL + - t->left != NIL + - t->left->right != NIL + - double_rotate_right() is only called in fix_after_add() + ensures + - performs a left rotation around t->left + - then performs a right rotation around t + !*/ + + inline void double_rotate_left ( + node* t + ); + /*! + requires + - t != NIL + - t->right != NIL + - t->right->left != NIL + - double_rotate_left() is only called in fix_after_add() + ensures + - performs a right rotation around t->right + - then performs a left rotation around t + !*/ + + void remove_biggest_element_in_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL (i.e. there must be something in the tree to remove) + ensures + - the biggest node in t has been removed + - the biggest node element in t has been put into #d and #r + - #t is still a binary search tree + !*/ + + bool remove_least_element_in_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL (i.e. there must be something in the tree to remove) + ensures + - the least node in t has been removed + - the least node element in t has been put into #d and #r + - #t is still a binary search tree + - if (the node that was removed was the one pointed to by current_element) then + - returns true + - else + - returns false + !*/ + + void add_to_tree ( + node* t, + domain& d, + range& r + ); + /*! + requires + - t != NIL + ensures + - d and r are now in #t + - there is a mapping from d to r in #t + - #d and #r have initial values for their types + - #t is still a binary search tree + !*/ + + void remove_from_tree ( + node* t, + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - return_reference(t,d) != 0 + ensures + - #d_copy is equivalent to d + - the first element in t equivalent to d that is encountered when searching down the tree + from t has been removed and swapped into #d_copy. Also, the associated range element + has been removed and swapped into #r. + - if (the node that got removed wasn't current_element) then + - adjusts the current_element pointer if the data in the node that it points to gets moved. + - else + - the value of current_element is now invalid + - #t is still a binary search tree + !*/ + + void remove_from_tree ( + node* t, + const domain& d + ); + /*! + requires + - return_reference(t,d) != 0 + ensures + - an element in t equivalent to d has been removed + - #t is still a binary search tree + !*/ + + const range* return_reference ( + const node* t, + const domain& d + ) const; + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + range* return_reference ( + node* t, + const domain& d + ); + /*! + ensures + - if (there is a domain element equivalent to d in t) then + - returns a pointer to the element in the range equivalent to d + - else + - returns 0 + !*/ + + void fix_after_add ( + node* t + ); + /*! + requires + - t == pointer to the node just added + - t->color == red + - t->parent != NIL (t must not be the root) + - fix_after_add() is only called after a new node has been added + to t + ensures + - fixes any deviations from the CONVENTION caused by adding a node + !*/ + + void fix_after_remove ( + node* t + ); + /*! + requires + - t == pointer to the only child of the node that was spliced out + - fix_after_remove() is only called after a node has been removed + from t + - the color of the spliced out node was black + ensures + - fixes any deviations from the CONVENTION causes by removing a node + !*/ + + + short tree_height ( + node* t + ) const; + /*! + ensures + - returns the number of nodes in the longest path from the root of the + tree to a leaf + !*/ + + void delete_tree ( + node* t + ); + /*! + requires + - t == root of binary search tree + - t != NIL + ensures + - deletes all nodes in t except for NIL + !*/ + + unsigned long get_count ( + const domain& item, + node* tree_root + ) const; + /*! + requires + - tree_root == the root of a binary search tree or NIL + ensures + - if (tree_root == NIL) then + - returns 0 + - else + - returns the number of elements in tree_root that are + equivalent to item + !*/ + + + + // data members + typename mem_manager::template rebind::other pool; + node* NIL; + unsigned long tree_size; + node* tree_root; + mutable node* current_element; + mutable bool at_start_; + mutable mpair p; + compare comp; + + + + // restricted functions + binary_search_tree_kernel_2(binary_search_tree_kernel_2&); + binary_search_tree_kernel_2& operator=(binary_search_tree_kernel_2&); + + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree_kernel_2& a, + binary_search_tree_kernel_2& b + ) { a.swap(b); } + + + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree_kernel_2& item, + std::istream& in + ) + { + try + { + item.clear(); + unsigned long size; + deserialize(size,in); + domain d; + range r; + for (unsigned long i = 0; i < size; ++i) + { + deserialize(d,in); + deserialize(r,in); + item.add(d,r); + } + } + catch (serialization_error& e) + { + item.clear(); + throw serialization_error(e.info + "\n while deserializing object of type binary_search_tree_kernel_2"); + } + } + + + + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + binary_search_tree_kernel_2:: + ~binary_search_tree_kernel_2 ( + ) + { + if (tree_root != NIL) + delete_tree(tree_root); + pool.deallocate(NIL); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + clear ( + ) + { + if (tree_size > 0) + { + delete_tree(tree_root); + tree_root = NIL; + tree_size = 0; + } + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + size_t binary_search_tree_kernel_2:: + size ( + ) const + { + return tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_2:: + height ( + ) const + { + return tree_height(tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_2:: + count ( + const domain& item + ) const + { + return get_count(item,tree_root); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + add ( + domain& d, + range& r + ) + { + if (tree_size == 0) + { + tree_root = pool.allocate(); + tree_root->color = black; + tree_root->left = NIL; + tree_root->right = NIL; + tree_root->parent = NIL; + exchange(tree_root->d,d); + exchange(tree_root->r,r); + } + else + { + add_to_tree(tree_root,d,r); + } + ++tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + remove_from_tree(tree_root,d,d_copy,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + destroy ( + const domain& item + ) + { + remove_from_tree(tree_root,item); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_any ( + domain& d, + range& r + ) + { + remove_least_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_2:: + operator[] ( + const domain& d + ) + { + return return_reference(tree_root,d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_2:: + operator[] ( + const domain& d + ) const + { + return return_reference(tree_root,d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + swap ( + binary_search_tree_kernel_2& item + ) + { + pool.swap(item.pool); + + exchange(p,item.p); + exchange(comp,item.comp); + + node* tree_root_temp = item.tree_root; + unsigned long tree_size_temp = item.tree_size; + node* const NIL_temp = item.NIL; + node* current_element_temp = item.current_element; + bool at_start_temp = item.at_start_; + + item.tree_root = tree_root; + item.tree_size = tree_size; + item.NIL = NIL; + item.current_element = current_element; + item.at_start_ = at_start_; + + tree_root = tree_root_temp; + tree_size = tree_size_temp; + NIL = NIL_temp; + current_element = current_element_temp; + at_start_ = at_start_temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_last_in_order ( + domain& d, + range& r + ) + { + remove_biggest_element_in_tree(tree_root,d,r); + --tree_size; + // reset the enumerator + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_current_element ( + domain& d, + range& r + ) + { + node* t = current_element; + move_next(); + remove_from_tree(t,t->d,d,r); + --tree_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + position_enumerator ( + const domain& d + ) const + { + // clear the enumerator state and make sure the stack is empty + reset(); + at_start_ = false; + node* t = tree_root; + node* parent = NIL; + bool went_left = false; + while (t != NIL) + { + if ( comp(d , t->d )) + { + // if item is on the left then look in left + parent = t; + t = t->left; + went_left = true; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + parent = t; + t = t->right; + went_left = false; + } + else + { + current_element = t; + return; + } + } + + // if we didn't find any matches but there might be something after the + // d in this tree. + if (parent != NIL) + { + current_element = parent; + // if we went left from this node then this node is the next + // biggest. + if (went_left) + { + return; + } + else + { + move_next(); + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // enumerable function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + at_start ( + ) const + { + return at_start_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + reset ( + ) const + { + at_start_ = true; + current_element = 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + current_element_valid ( + ) const + { + return (current_element != 0); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const map_pair& binary_search_tree_kernel_2:: + element ( + ) const + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + map_pair& binary_search_tree_kernel_2:: + element ( + ) + { + p.d = &(current_element->d); + p.r = &(current_element->r); + return p; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + move_next ( + ) const + { + // if we haven't started iterating yet + if (at_start_) + { + at_start_ = false; + if (tree_size == 0) + { + return false; + } + else + { + // find the first element in the tree + current_element = tree_root; + node* temp = current_element->left; + while (temp != NIL) + { + current_element = temp; + temp = current_element->left; + } + return true; + } + } + else + { + if (current_element == 0) + { + return false; + } + else + { + bool went_up; // true if we went up the tree from a child node to parent + bool from_left = false; // true if we went up and were coming from a left child node + // find the next element in the tree + if (current_element->right != NIL) + { + // go right and down + current_element = current_element->right; + went_up = false; + } + else + { + went_up = true; + node* parent = current_element->parent; + if (parent == NIL) + { + // in this case we have iterated over all the element of the tree + current_element = 0; + return false; + } + + from_left = (parent->left == current_element); + // go up to parent + current_element = parent; + } + + + while (true) + { + if (went_up) + { + if (from_left) + { + // in this case we have found the next node + break; + } + else + { + // we should go up + node* parent = current_element->parent; + from_left = (parent->left == current_element); + current_element = parent; + if (current_element == NIL) + { + // in this case we have iterated over all the elements + // in the tree + current_element = 0; + return false; + } + } + } + else + { + // we just went down to a child node + if (current_element->left != NIL) + { + // go left + went_up = false; + current_element = current_element->left; + } + else + { + // if there is no left child then we have found the next node + break; + } + } + } + + return true; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + delete_tree ( + node* t + ) + { + if (t->left != NIL) + delete_tree(t->left); + if (t->right != NIL) + delete_tree(t->right); + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + rotate_left ( + node* t + ) + { + + // perform the rotation + node* temp = t->right; + t->right = temp->left; + if (temp->left != NIL) + temp->left->parent = t; + temp->left = t; + temp->parent = t->parent; + + + if (t == tree_root) + tree_root = temp; + else + { + // if t was on the left + if (t->parent->left == t) + t->parent->left = temp; + else + t->parent->right = temp; + } + + t->parent = temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + rotate_right ( + node* t + ) + { + // perform the rotation + node* temp = t->left; + t->left = temp->right; + if (temp->right != NIL) + temp->right->parent = t; + temp->right = t; + temp->parent = t->parent; + + if (t == tree_root) + tree_root = temp; + else + { + // if t is a left child + if (t->parent->left == t) + t->parent->left = temp; + else + t->parent->right = temp; + } + + t->parent = temp; + } + + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + double_rotate_right ( + node* t + ) + { + + // preform the rotation + node& temp = *(t->left->right); + t->left = temp.right; + temp.right->parent = t; + temp.left->parent = temp.parent; + temp.parent->right = temp.left; + temp.parent->parent = &temp; + temp.right = t; + temp.left = temp.parent; + temp.parent = t->parent; + + + if (tree_root == t) + tree_root = &temp; + else + { + // t is a left child + if (t->parent->left == t) + t->parent->left = &temp; + else + t->parent->right = &temp; + } + t->parent = &temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + double_rotate_left ( + node* t + ) + { + + + // preform the rotation + node& temp = *(t->right->left); + t->right = temp.left; + temp.left->parent = t; + temp.right->parent = temp.parent; + temp.parent->left = temp.right; + temp.parent->parent = &temp; + temp.left = t; + temp.right = temp.parent; + temp.parent = t->parent; + + + if (tree_root == t) + tree_root = &temp; + else + { + // t is a left child + if (t->parent->left == t) + t->parent->left = &temp; + else + t->parent->right = &temp; + } + t->parent = &temp; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_biggest_element_in_tree ( + node* t, + domain& d, + range& r + ) + { + + node* next = t->right; + node* child; // the child node of the one we will slice out + + if (next == NIL) + { + // need to determine if t is a right or left child + if (t->parent->right == t) + child = t->parent->right = t->left; + else + child = t->parent->left = t->left; + + // update tree_root if necessary + if (t == tree_root) + tree_root = child; + } + else + { + // find the least node + do + { + t = next; + next = next->right; + } while (next != NIL); + // t is a right child + child = t->parent->right = t->left; + + } + + // swap the item from this node into d and r + exchange(d,t->d); + exchange(r,t->r); + + // plug hole right by removing this node + child->parent = t->parent; + + // keep the red-black properties true + if (t->color == black) + fix_after_remove(child); + + // free the memory for this removed node + pool.deallocate(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + bool binary_search_tree_kernel_2:: + remove_least_element_in_tree ( + node* t, + domain& d, + range& r + ) + { + + node* next = t->left; + node* child; // the child node of the one we will slice out + + if (next == NIL) + { + // need to determine if t is a left or right child + if (t->parent->left == t) + child = t->parent->left = t->right; + else + child = t->parent->right = t->right; + + // update tree_root if necessary + if (t == tree_root) + tree_root = child; + } + else + { + // find the least node + do + { + t = next; + next = next->left; + } while (next != NIL); + // t is a left child + child = t->parent->left = t->right; + + } + + // swap the item from this node into d and r + exchange(d,t->d); + exchange(r,t->r); + + // plug hole left by removing this node + child->parent = t->parent; + + // keep the red-black properties true + if (t->color == black) + fix_after_remove(child); + + bool rvalue = (t == current_element); + // free the memory for this removed node + pool.deallocate(t); + return rvalue; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + add_to_tree ( + node* t, + domain& d, + range& r + ) + { + // parent of the current node + node* parent; + + // find a place to add node + while (true) + { + parent = t; + // if item should be put on the left then go left + if (comp(d , t->d)) + { + t = t->left; + if (t == NIL) + { + t = parent->left = pool.allocate(); + break; + } + } + // if item should be put on the right then go right + else + { + t = t->right; + if (t == NIL) + { + t = parent->right = pool.allocate(); + break; + } + } + } + + // t is now the node where we will add item and + // parent is the parent of t + + t->parent = parent; + t->left = NIL; + t->right = NIL; + t->color = red; + exchange(t->d,d); + exchange(t->r,r); + + + // keep the red-black properties true + fix_after_add(t); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_from_tree ( + node* t, + const domain& d, + domain& d_copy, + range& r + ) + { + while (true) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // found the node we want to remove + + // swap out the item into d_copy and r + exchange(d_copy,t->d); + exchange(r,t->r); + + if (t->left == NIL) + { + // if there is no left subtree + + node* parent = t->parent; + + // plug hole with right subtree + + + // if t is on the left + if (parent->left == t) + parent->left = t->right; + else + parent->right = t->right; + t->right->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->right; + + if (t->color == black) + fix_after_remove(t->right); + + // delete old node + pool.deallocate(t); + } + else if (t->right == NIL) + { + // if there is no right subtree + + node* parent = t->parent; + + // plug hole with left subtree + if (parent->left == t) + parent->left = t->left; + else + parent->right = t->left; + t->left->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->left; + + if (t->color == black) + fix_after_remove(t->left); + + // delete old node + pool.deallocate(t); + } + else + { + // if there is both a left and right subtree + // get an element to fill this node now that its been swapped into + // item_copy + if (remove_least_element_in_tree(t->right,t->d,t->r)) + { + // the node removed was the one pointed to by current_element so we + // need to update it so that it points to the right spot. + current_element = t; + } + } + + // quit loop + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + remove_from_tree ( + node* t, + const domain& d + ) + { + while (true) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // found the node we want to remove + + + if (t->left == NIL) + { + // if there is no left subtree + + node* parent = t->parent; + + // plug hole with right subtree + + + if (parent->left == t) + parent->left = t->right; + else + parent->right = t->right; + t->right->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->right; + + if (t->color == black) + fix_after_remove(t->right); + + // delete old node + pool.deallocate(t); + } + else if (t->right == NIL) + { + // if there is no right subtree + + node* parent = t->parent; + + // plug hole with left subtree + if (parent->left == t) + parent->left = t->left; + else + parent->right = t->left; + t->left->parent = parent; + + // update tree_root if necessary + if (t == tree_root) + tree_root = t->left; + + if (t->color == black) + fix_after_remove(t->left); + + // delete old node + pool.deallocate(t); + } + else + { + // if there is both a left and right subtree + // get an element to fill this node now that its been swapped into + // item_copy + remove_least_element_in_tree(t->right,t->d,t->r); + + } + + // quit loop + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + range* binary_search_tree_kernel_2:: + return_reference ( + node* t, + const domain& d + ) + { + while (t != NIL) + { + if ( comp(d , t->d )) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + const range* binary_search_tree_kernel_2:: + return_reference ( + const node* t, + const domain& d + ) const + { + while (t != NIL) + { + if ( comp(d , t->d) ) + { + // if item is on the left then look in left + t = t->left; + } + else if (comp(t->d , d)) + { + // if item is on the right then look in right + t = t->right; + } + else + { + // if it's found then return a reference to it + return &(t->r); + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + fix_after_add ( + node* t + ) + { + + while (t->parent->color == red) + { + node& grandparent = *(t->parent->parent); + + // if both t's parent and its sibling are red + if (grandparent.left->color == grandparent.right->color) + { + grandparent.color = red; + grandparent.left->color = black; + grandparent.right->color = black; + t = &grandparent; + } + else + { + // if t is a left child + if (t == t->parent->left) + { + // if t's parent is a left child + if (t->parent == grandparent.left) + { + grandparent.color = red; + grandparent.left->color = black; + rotate_right(&grandparent); + } + // if t's parent is a right child + else + { + t->color = black; + grandparent.color = red; + double_rotate_left(&grandparent); + } + } + // if t is a right child + else + { + // if t's parent is a left child + if (t->parent == grandparent.left) + { + t->color = black; + grandparent.color = red; + double_rotate_right(&grandparent); + } + // if t's parent is a right child + else + { + grandparent.color = red; + grandparent.right->color = black; + rotate_left(&grandparent); + } + } + break; + } + } + tree_root->color = black; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void binary_search_tree_kernel_2:: + fix_after_remove ( + node* t + ) + { + + while (t != tree_root && t->color == black) + { + if (t->parent->left == t) + { + node* sibling = t->parent->right; + if (sibling->color == red) + { + sibling->color = black; + t->parent->color = red; + rotate_left(t->parent); + sibling = t->parent->right; + } + + if (sibling->left->color == black && sibling->right->color == black) + { + sibling->color = red; + t = t->parent; + } + else + { + if (sibling->right->color == black) + { + sibling->left->color = black; + sibling->color = red; + rotate_right(sibling); + sibling = t->parent->right; + } + + sibling->color = t->parent->color; + t->parent->color = black; + sibling->right->color = black; + rotate_left(t->parent); + t = tree_root; + + } + + + } + else + { + + node* sibling = t->parent->left; + if (sibling->color == red) + { + sibling->color = black; + t->parent->color = red; + rotate_right(t->parent); + sibling = t->parent->left; + } + + if (sibling->left->color == black && sibling->right->color == black) + { + sibling->color = red; + t = t->parent; + } + else + { + if (sibling->left->color == black) + { + sibling->right->color = black; + sibling->color = red; + rotate_left(sibling); + sibling = t->parent->left; + } + + sibling->color = t->parent->color; + t->parent->color = black; + sibling->left->color = black; + rotate_right(t->parent); + t = tree_root; + + } + + + } + + } + t->color = black; + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + short binary_search_tree_kernel_2:: + tree_height ( + node* t + ) const + { + if (t == NIL) + return 0; + + short height1 = tree_height(t->left); + short height2 = tree_height(t->right); + if (height1 > height2) + return height1 + 1; + else + return height2 + 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + unsigned long binary_search_tree_kernel_2:: + get_count ( + const domain& d, + node* tree_root + ) const + { + if (tree_root != NIL) + { + if (comp(d , tree_root->d)) + { + // go left + return get_count(d,tree_root->left); + } + else if (comp(tree_root->d , d)) + { + // go right + return get_count(d,tree_root->right); + } + else + { + // go left and right to look for more matches + return get_count(d,tree_root->left) + + get_count(d,tree_root->right) + + 1; + } + } + return 0; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_2_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h b/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..2abfe7e3955315cffbd90ef3ed897580ebfa73e4 --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_abstract.h @@ -0,0 +1,311 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ +#ifdef DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ + +#include "../interfaces/map_pair.h" +#include "../interfaces/enumerable.h" +#include "../interfaces/remover.h" +#include "../serialize.h" +#include "../algs.h" +#include + +namespace dlib +{ + + template < + typename domain, + typename range, + typename mem_manager = default_memory_manager, + typename compare = std::less + > + class binary_search_tree : public enumerable >, + public asc_pair_remover + { + + /*! + REQUIREMENTS ON domain + domain must be comparable by compare where compare is a functor compatible with std::less and + domain is swappable by a global swap() and + domain must have a default constructor + + REQUIREMENTS ON range + range is swappable by a global swap() and + range must have a default constructor + + REQUIREMENTS ON mem_manager + must be an implementation of memory_manager/memory_manager_kernel_abstract.h or + must be an implementation of memory_manager_global/memory_manager_global_kernel_abstract.h or + must be an implementation of memory_manager_stateless/memory_manager_stateless_kernel_abstract.h + mem_manager::type can be set to anything. + + + POINTERS AND REFERENCES TO INTERNAL DATA + swap(), count(), height(), and operator[] functions + do not invalidate pointers or references to internal data. + + position_enumerator() invalidates pointers or references to + data returned by element() and only by element() (i.e. pointers and + references returned by operator[] are still valid). + + All other functions have no such guarantees. + + INITIAL VALUE + size() == 0 + height() == 0 + + ENUMERATION ORDER + The enumerator will iterate over the domain (and each associated + range element) elements in ascending order according to the compare functor. + (i.e. the elements are enumerated in sorted order) + + WHAT THIS OBJECT REPRESENTS + this object represents a data dictionary that is built on top of some + kind of binary search tree. It maps objects of type domain to objects + of type range. + + Also note that unless specified otherwise, no member functions + of this object throw exceptions. + + NOTE: + definition of equivalent: + a is equivalent to b if + a < b == false and + b < a == false + !*/ + + + public: + + typedef domain domain_type; + typedef range range_type; + typedef compare compare_type; + typedef mem_manager mem_manager_type; + + binary_search_tree( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + !*/ + + virtual ~binary_search_tree( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + short height ( + ) const; + /*! + ensures + - returns the number of elements in the longest path from the root + of the tree to a leaf + !*/ + + unsigned long count ( + const domain& d + ) const; + /*! + ensures + - returns the number of elements in the domain of *this that are + equivalent to d + !*/ + + void add ( + domain& d, + range& r + ); + /*! + requires + - &d != &r (i.e. d and r cannot be the same variable) + ensures + - adds a mapping between d and r to *this + - if (count(d) == 0) then + - #*(*this)[d] == r + - else + - #(*this)[d] != 0 + - #d and #r have initial values for their types + - #count(d) == count(d) + 1 + - #at_start() == true + - #size() == size() + 1 + throws + - std::bad_alloc or any exception thrown by domain's or range's + constructor. + if add() throws then it has no effect + !*/ + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + /*! + requires + - (*this)[d] != 0 + - &d != &r (i.e. d and r cannot be the same variable) + - &d != &d_copy (i.e. d and d_copy cannot be the same variable) + - &r != &d_copy (i.e. r and d_copy cannot be the same variable) + ensures + - some element in the domain of *this that is equivalent to d has + been removed and swapped into #d_copy. Additionally, its + associated range element has been removed and swapped into #r. + - #count(d) == count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void destroy ( + const domain& d + ); + /*! + requires + - (*this)[d] != 0 + ensures + - an element in the domain of *this equivalent to d has been removed. + The element in the range of *this associated with d has also been + removed. + - #count(d) == count(d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void remove_last_in_order ( + domain& d, + range& r + ); + /*! + requires + - size() > 0 + ensures + - the last/biggest (according to the compare functor) element in the domain of *this has + been removed and swapped into #d. The element in the range of *this + associated with #d has also been removed and swapped into #r. + - #count(#d) == count(#d) - 1 + - #size() == size() - 1 + - #at_start() == true + !*/ + + void remove_current_element ( + domain& d, + range& r + ); + /*! + requires + - current_element_valid() == true + ensures + - the current element given by element() has been removed and swapped into d and r. + - #d == element().key() + - #r == element().value() + - #count(#d) == count(#d) - 1 + - #size() == size() - 1 + - moves the enumerator to the next element. If element() was the last + element in enumeration order then #current_element_valid() == false + and #at_start() == false. + !*/ + + void position_enumerator ( + const domain& d + ) const; + /*! + ensures + - #at_start() == false + - if (count(d) > 0) then + - #element().key() == d + - else if (there are any items in the domain of *this that are bigger than + d according to the compare functor) then + - #element().key() == the smallest item in the domain of *this that is + bigger than d according to the compare functor. + - else + - #current_element_valid() == false + !*/ + + const range* operator[] ( + const domain& d + ) const; + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + range* operator[] ( + const domain& d + ); + /*! + ensures + - if (there is an element in the domain equivalent to d) then + - returns a pointer to an element in the range of *this that + is associated with an element in the domain of *this + equivalent to d. + - else + - returns 0 + !*/ + + void swap ( + binary_search_tree& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + binary_search_tree(binary_search_tree&); + binary_search_tree& operator=(binary_search_tree&); + + }; + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + inline void swap ( + binary_search_tree& a, + binary_search_tree& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + + template < + typename domain, + typename range, + typename mem_manager, + typename compare + > + void deserialize ( + binary_search_tree& item, + std::istream& in + ); + /*! + provides deserialization support + !*/ +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_ABSTRACT_ + diff --git a/dlib/binary_search_tree/binary_search_tree_kernel_c.h b/dlib/binary_search_tree/binary_search_tree_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..0dc15396193f699b484244697d23b75adb4aa425 --- /dev/null +++ b/dlib/binary_search_tree/binary_search_tree_kernel_c.h @@ -0,0 +1,235 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BINARY_SEARCH_TREE_KERNEl_C_ +#define DLIB_BINARY_SEARCH_TREE_KERNEl_C_ + +#include "../interfaces/map_pair.h" +#include "binary_search_tree_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename bst_base + > + class binary_search_tree_kernel_c : public bst_base + { + typedef typename bst_base::domain_type domain; + typedef typename bst_base::range_type range; + + public: + + binary_search_tree_kernel_c () {} + + void remove ( + const domain& d, + domain& d_copy, + range& r + ); + + void destroy ( + const domain& d + ); + + void add ( + domain& d, + range& r + ); + + void remove_any ( + domain& d, + range& r + ); + + const map_pair& element( + ) const + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst map_pair& binary_search_tree::element() const" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return bst_base::element(); + } + + map_pair& element( + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tmap_pair& binary_search_tree::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + return bst_base::element(); + } + + void remove_last_in_order ( + domain& d, + range& r + ); + + void remove_current_element ( + domain& d, + range& r + ); + + + }; + + + template < + typename bst_base + > + inline void swap ( + binary_search_tree_kernel_c& a, + binary_search_tree_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + add ( + domain& d, + range& r + ) + { + DLIB_CASSERT( static_cast(&d) != static_cast(&r), + "\tvoid binary_search_tree::add" + << "\n\tyou can't call add() and give the same object to both parameters." + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + << "\n\tsize(): " << this->size() + ); + + bst_base::add(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + destroy ( + const domain& d + ) + { + DLIB_CASSERT(this->operator[](d) != 0, + "\tvoid binary_search_tree::destroy" + << "\n\tthe element must be in the tree for it to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + ); + + bst_base::destroy(d); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove ( + const domain& d, + domain& d_copy, + range& r + ) + { + DLIB_CASSERT(this->operator[](d) != 0 && + (static_cast(&d) != static_cast(&d_copy)) && + (static_cast(&d) != static_cast(&r)) && + (static_cast(&r) != static_cast(&d_copy)), + "\tvoid binary_search_tree::remove" + << "\n\tthe element must be in the tree for it to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&d_copy: " << &d_copy + << "\n\t&r: " << &r + ); + + bst_base::remove(d,d_copy,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_any( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->size() != 0 && + (static_cast(&d) != static_cast(&r)), + "\tvoid binary_search_tree::remove_any" + << "\n\ttree must not be empty if something is going to be removed" + << "\n\tthis: " << this + << "\n\t&d: " << &d + << "\n\t&r: " << &r + ); + + bst_base::remove_any(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_last_in_order ( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->size() > 0, + "\tvoid binary_search_tree::remove_last_in_order()" + << "\n\tyou can't remove an element if it doesn't exist" + << "\n\tthis: " << this + ); + + bst_base::remove_last_in_order(d,r); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bst_base + > + void binary_search_tree_kernel_c:: + remove_current_element ( + domain& d, + range& r + ) + { + DLIB_CASSERT(this->current_element_valid() == true, + "\tvoid binary_search_tree::remove_current_element()" + << "\n\tyou can't remove the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + bst_base::remove_current_element(d,r); + } + + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BINARY_SEARCH_TREE_KERNEl_C_ + diff --git a/dlib/bit_stream.h b/dlib/bit_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..8885f35157e41d5e9c19c9204215a3763ac848e0 --- /dev/null +++ b/dlib/bit_stream.h @@ -0,0 +1,42 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAm_ +#define DLIB_BIT_STREAm_ + +#include "bit_stream/bit_stream_kernel_1.h" +#include "bit_stream/bit_stream_kernel_c.h" + +#include "bit_stream/bit_stream_multi_1.h" +#include "bit_stream/bit_stream_multi_c.h" + +namespace dlib +{ + + + class bit_stream + { + bit_stream() {} + public: + + //----------- kernels --------------- + + // kernel_1a + typedef bit_stream_kernel_1 + kernel_1a; + typedef bit_stream_kernel_c + kernel_1a_c; + + //---------- extensions ------------ + + + // multi_1 extend kernel_1a + typedef bit_stream_multi_1 + multi_1a; + typedef bit_stream_multi_c > + multi_1a_c; + + }; +} + +#endif // DLIB_BIT_STREAm_ + diff --git a/dlib/bit_stream/bit_stream_kernel_1.cpp b/dlib/bit_stream/bit_stream_kernel_1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f49db14d59bbd4fa3da223aabf2b67f2bc34490e --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_1.cpp @@ -0,0 +1,200 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEL_1_CPp_ +#define DLIB_BIT_STREAM_KERNEL_1_CPp_ + + +#include "bit_stream_kernel_1.h" +#include "../algs.h" + +#include + +namespace dlib +{ + + inline void swap ( + bit_stream_kernel_1& a, + bit_stream_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + clear ( + ) + { + if (write_mode) + { + write_mode = false; + + // flush output buffer + if (buffer_size > 0) + { + buffer <<= 8 - buffer_size; + osp->write(reinterpret_cast(&buffer),1); + } + } + else + read_mode = false; + + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + set_input_stream ( + std::istream& is + ) + { + isp = &is; + read_mode = true; + + buffer_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + set_output_stream ( + std::ostream& os + ) + { + osp = &os; + write_mode = true; + + buffer_size = 0; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + close ( + ) + { + if (write_mode) + { + write_mode = false; + + // flush output buffer + if (buffer_size > 0) + { + buffer <<= 8 - buffer_size; + osp->write(reinterpret_cast(&buffer),1); + } + } + else + read_mode = false; + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + is_in_write_mode ( + ) const + { + return write_mode; + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + is_in_read_mode ( + ) const + { + return read_mode; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + write ( + int bit + ) + { + // flush buffer if necessary + if (buffer_size == 8) + { + buffer <<= 8 - buffer_size; + if (osp->rdbuf()->sputn(reinterpret_cast(&buffer),1) == 0) + { + throw std::ios_base::failure("error occurred in the bit_stream object"); + } + + buffer_size = 0; + } + + ++buffer_size; + buffer <<= 1; + buffer += static_cast(bit); + } + +// ---------------------------------------------------------------------------------------- + + bool bit_stream_kernel_1:: + read ( + int& bit + ) + { + // get new byte if necessary + if (buffer_size == 0) + { + if (isp->rdbuf()->sgetn(reinterpret_cast(&buffer), 1) == 0) + { + // if we didn't read anything then return false + return false; + } + + buffer_size = 8; + } + + // put the most significant bit from buffer into bit + bit = static_cast(buffer >> 7); + + // shift out the bit that was just read + buffer <<= 1; + --buffer_size; + + return true; + } + +// ---------------------------------------------------------------------------------------- + + void bit_stream_kernel_1:: + swap ( + bit_stream_kernel_1& item + ) + { + + std::istream* isp_temp = item.isp; + std::ostream* osp_temp = item.osp; + bool write_mode_temp = item.write_mode; + bool read_mode_temp = item.read_mode; + unsigned char buffer_temp = item.buffer; + unsigned short buffer_size_temp = item.buffer_size; + + item.isp = isp; + item.osp = osp; + item.write_mode = write_mode; + item.read_mode = read_mode; + item.buffer = buffer; + item.buffer_size = buffer_size; + + + isp = isp_temp; + osp = osp_temp; + write_mode = write_mode_temp; + read_mode = read_mode_temp; + buffer = buffer_temp; + buffer_size = buffer_size_temp; + + } + +// ---------------------------------------------------------------------------------------- + +} +#endif // DLIB_BIT_STREAM_KERNEL_1_CPp_ + diff --git a/dlib/bit_stream/bit_stream_kernel_1.h b/dlib/bit_stream/bit_stream_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..801e93e0a27facbc6a00b59c9241452a4486d62a --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_1.h @@ -0,0 +1,120 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEl_1_ +#define DLIB_BIT_STREAM_KERNEl_1_ + +#include "bit_stream_kernel_abstract.h" +#include + +namespace dlib +{ + + class bit_stream_kernel_1 + { + + /*! + INITIAL VALUE + write_mode == false + read_mode == false + + CONVENTION + write_mode == is_in_write_mode() + read_mode == is_in_read_mode() + + if (write_mode) + { + osp == pointer to an ostream object + buffer == the low order bits of buffer are the bits to be + written + buffer_size == the number of low order bits in buffer that are + bits that should be written + the lowest order bit is the last bit entered by the user + } + + if (read_mode) + { + isp == pointer to an istream object + buffer == the high order bits of buffer are the bits + waiting to be read by the user + buffer_size == the number of high order bits in buffer that + are bits that are waiting to be read + the highest order bit is the next bit to give to the user + } + !*/ + + + public: + + bit_stream_kernel_1 ( + ) : + write_mode(false), + read_mode(false) + {} + + virtual ~bit_stream_kernel_1 ( + ) + {} + + void clear ( + ); + + void set_input_stream ( + std::istream& is + ); + + void set_output_stream ( + std::ostream& os + ); + + void close ( + ); + + inline bool is_in_write_mode ( + ) const; + + inline bool is_in_read_mode ( + ) const; + + inline void write ( + int bit + ); + + bool read ( + int& bit + ); + + void swap ( + bit_stream_kernel_1& item + ); + + private: + + // member data + std::istream* isp; + std::ostream* osp; + bool write_mode; + bool read_mode; + unsigned char buffer; + unsigned short buffer_size; + + // restricted functions + bit_stream_kernel_1(bit_stream_kernel_1&); // copy constructor + bit_stream_kernel_1& operator=(bit_stream_kernel_1&); // assignment operator + + }; + + inline void swap ( + bit_stream_kernel_1& a, + bit_stream_kernel_1& b + ); + +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bit_stream_kernel_1.cpp" +#endif + +#endif // DLIB_BIT_STREAM_KERNEl_1_ + diff --git a/dlib/bit_stream/bit_stream_kernel_abstract.h b/dlib/bit_stream/bit_stream_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..00c2ae3b94d2594170d14d48c8390a261c51c8fd --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_abstract.h @@ -0,0 +1,185 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ +#ifdef DLIB_BIT_STREAM_KERNEl_ABSTRACT_ + +#include + +namespace dlib +{ + + class bit_stream + { + + /*! + INITIAL VALUE + is_in_write_mode() == false + is_in_read_mode() == false + + WHAT THIS OBJECT REPRESENTS + this object is a middle man between a user and the iostream classes. + it allows single bits to be read/written easily to/from + the iostream classes + + BUFFERING: + This object will only read/write single bytes at a time from/to the + iostream objects. Any buffered bits still in the bit_stream object + when it is closed or destructed are lost if it is in read mode. If + it is in write mode then any remaining bits are guaranteed to be + written to the output stream by the time it is closed or destructed. + !*/ + + + public: + + bit_stream ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~bit_stream ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear ( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + + void set_input_stream ( + std::istream& is + ); + /*! + requires + - is_in_write_mode() == false + - is_in_read_mode() == false + - is is ready to give input + ensures + - #is_in_write_mode() == false + - #is_in_read_mode() == true + - #*this will now be reading from is + throws + - std::bad_alloc + !*/ + + void set_output_stream ( + std::ostream& os + ); + /*! + requires + - is_in_write_mode() == false + - is_in_read_mode() == false + - os is ready to take output + ensures + - #is_in_write_mode() == true + - #is_in_read_mode() == false + - #*this will now write to os + throws + - std::bad_alloc + !*/ + + + + void close ( + ); + /*! + requires + - is_in_write_mode() == true || is_in_read_mode() == true + ensures + - #is_in_write_mode() == false + - #is_in_read_mode() == false + !*/ + + bool is_in_write_mode ( + ) const; + /*! + ensures + - returns true if *this is associated with an output stream object + - returns false otherwise + !*/ + + bool is_in_read_mode ( + ) const; + /*! + ensures + - returns true if *this is associated with an input stream object + - returns false otherwise + !*/ + + void write ( + int bit + ); + /*! + requires + - is_in_write_mode() == true + - bit == 0 || bit == 1 + ensures + - bit will be written to the ostream object associated with *this + throws + - std::ios_base::failure + if (there was a problem writing to the output stream) then + this exception will be thrown. #*this will be unusable until + clear() is called and succeeds + - any other exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + bool read ( + int& bit + ); + /*! + requires + - is_in_read_mode() == true + ensures + - the next bit has been read and placed into #bit + - returns true if the read was successful, else false + (ex. false if EOF has been reached) + throws + - any exception + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void swap ( + bit_stream& item + ); + /*! + ensures + - swaps *this and item + !*/ + + private: + + // restricted functions + bit_stream(bit_stream&); // copy constructor + bit_stream& operator=(bit_stream&); // assignment operator + + }; + + inline void swap ( + bit_stream& a, + bit_stream& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_BIT_STREAM_KERNEl_ABSTRACT_ + diff --git a/dlib/bit_stream/bit_stream_kernel_c.h b/dlib/bit_stream/bit_stream_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..1d52bff200f4aae532856cbe84e8f01208e2f7e1 --- /dev/null +++ b/dlib/bit_stream/bit_stream_kernel_c.h @@ -0,0 +1,172 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_KERNEl_C_ +#define DLIB_BIT_STREAM_KERNEl_C_ + +#include "bit_stream_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename bit_stream_base // implements bit_stream/bit_stream_kernel_abstract.h + > + class bit_stream_kernel_c : public bit_stream_base + { + public: + + + void set_input_stream ( + std::istream& is + ); + + void set_output_stream ( + std::ostream& os + ); + + void close ( + ); + + void write ( + int bit + ); + + bool read ( + int& bit + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_kernel_c& a, + bit_stream_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + set_input_stream ( + std::istream& is + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), + "\tvoid bit_stream::set_intput_stream" + << "\n\tbit_stream must not be in write or read mode" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::set_input_stream(is); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + set_output_stream ( + std::ostream& os + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == false ) && ( this->is_in_read_mode() == false ), + "\tvoid bit_stream::set_output_stream" + << "\n\tbit_stream must not be in write or read mode" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::set_output_stream(os); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + close ( + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == true ) || ( this->is_in_read_mode() == true ), + "\tvoid bit_stream::close" + << "\n\tyou can't close a bit_stream that isn't open" + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::close(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_kernel_c:: + write ( + int bit + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_write_mode() == true ) && ( bit == 0 || bit == 1 ), + "\tvoid bit_stream::write" + << "\n\tthe bit stream bust be in write mode and bit must be either 1 or 0" + << "\n\tis_in_write_mode() == " << this->is_in_write_mode() + << "\n\tbit == " << bit + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::write(bit); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + bool bit_stream_kernel_c:: + read ( + int& bit + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_read_mode() == true ), + "\tbool bit_stream::read" + << "\n\tyou can't read from a bit_stream that isn't in read mode" + << "\n\tthis: " << this + ); + + // call the real function + return bit_stream_base::read(bit); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_KERNEl_C_ + diff --git a/dlib/bit_stream/bit_stream_multi_1.h b/dlib/bit_stream/bit_stream_multi_1.h new file mode 100644 index 0000000000000000000000000000000000000000..bf1cc0357d86606416ce105c5cbf82c3f2f32c82 --- /dev/null +++ b/dlib/bit_stream/bit_stream_multi_1.h @@ -0,0 +1,103 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_MULTi_1_ +#define DLIB_BIT_STREAM_MULTi_1_ + +#include "bit_stream_multi_abstract.h" + +namespace dlib +{ + template < + typename bit_stream_base + > + class bit_stream_multi_1 : public bit_stream_base + { + + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + + int multi_read ( + unsigned long& data, + int num_to_read + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi_1& a, + bit_stream_multi_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_multi_1:: + multi_write ( + unsigned long data, + int num_to_write + ) + { + // move the first bit into the most significant position + data <<= 32 - num_to_write; + + for (int i = 0; i < num_to_write; ++i) + { + // write the first bit from data + this->write(static_cast(data >> 31)); + + // shift the next bit into position + data <<= 1; + + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + int bit_stream_multi_1:: + multi_read ( + unsigned long& data, + int num_to_read + ) + { + int bit, i; + data = 0; + for (i = 0; i < num_to_read; ++i) + { + + // get a bit + if (this->read(bit) == false) + break; + + // shift data to make room for this new bit + data <<= 1; + + // put bit into the least significant position in data + data += static_cast(bit); + + } + + return i; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_MULTi_1_ + diff --git a/dlib/bit_stream/bit_stream_multi_abstract.h b/dlib/bit_stream/bit_stream_multi_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..061af94f49e15a239303b57e40bedae7329912f0 --- /dev/null +++ b/dlib/bit_stream/bit_stream_multi_abstract.h @@ -0,0 +1,77 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BIT_STREAM_MULTi_ABSTRACT_ +#ifdef DLIB_BIT_STREAM_MULTi_ABSTRACT_ + +#include "bit_stream_kernel_abstract.h" + +namespace dlib +{ + template < + typename bit_stream_base + > + class bit_stream_multi : public bit_stream_base + { + + /*! + REQUIREMENTS ON BIT_STREAM_BASE + it is an implementation of bit_stream/bit_stream_kernel_abstract.h + + + WHAT THIS EXTENSION DOES FOR BIT_STREAM + this gives a bit_stream object the ability to read/write multible bits + at a time + !*/ + + + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + /*! + requires + - is_in_write_mode() == true + - 0 <= num_to_write <= 32 + ensures + - num_to_write low order bits from data will be written to the ostream + - object associated with *this + example: if data is 10010 then the bits will be written in the + order 1,0,0,1,0 + !*/ + + + int multi_read ( + unsigned long& data, + int num_to_read + ); + /*! + requires + - is_in_read_mode() == true + - 0 <= num_to_read <= 32 + ensures + - tries to read num_to_read bits into the low order end of #data + example: if the incoming bits were 10010 then data would end + up with 10010 as its low order bits + - all of the bits in #data not filled in by multi_read() are zero + - returns the number of bits actually read into #data + !*/ + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi& a, + bit_stream_multi& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +} + +#endif // DLIB_BIT_STREAM_MULTi_ABSTRACT_ + diff --git a/dlib/bit_stream/bit_stream_multi_c.h b/dlib/bit_stream/bit_stream_multi_c.h new file mode 100644 index 0000000000000000000000000000000000000000..de80c63280a9168d4152bd6b9ccbabe9bc3ab1da --- /dev/null +++ b/dlib/bit_stream/bit_stream_multi_c.h @@ -0,0 +1,101 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BIT_STREAM_MULTi_C_ +#define DLIB_BIT_STREAM_MULTi_C_ + +#include "bit_stream_multi_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + template < + typename bit_stream_base // implements bit_stream/bit_stream_multi_abstract.h + > + class bit_stream_multi_c : public bit_stream_base + { + public: + + void multi_write ( + unsigned long data, + int num_to_write + ); + + int multi_read ( + unsigned long& data, + int num_to_read + ); + + }; + + template < + typename bit_stream_base + > + inline void swap ( + bit_stream_multi_c& a, + bit_stream_multi_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + void bit_stream_multi_c:: + multi_write ( + unsigned long data, + int num_to_write + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( (this->is_in_write_mode() == true) && (num_to_write >= 0 && num_to_write <=32), + "\tvoid bit_stream::write" + << "\n\tthe bit stream bust be in write mode and" + << "\n\tnum_to_write must be between 0 and 32 inclusive" + << "\n\tnum_to_write == " << num_to_write + << "\n\tis_in_write_mode() == " << this->is_in_write_mode() + << "\n\tthis: " << this + ); + + // call the real function + bit_stream_base::multi_write(data,num_to_write); + + } + +// ---------------------------------------------------------------------------------------- + + template < + typename bit_stream_base + > + int bit_stream_multi_c:: + multi_read ( + unsigned long& data, + int num_to_read + ) + { + + // make sure requires clause is not broken + DLIB_CASSERT(( this->is_in_read_mode() == true && ( num_to_read >= 0 && num_to_read <=32 ) ), + "\tvoid bit_stream::read" + << "\n\tyou can't read from a bit_stream that isn't in read mode and" + << "\n\tnum_to_read must be between 0 and 32 inclusive" + << "\n\tnum_to_read == " << num_to_read + << "\n\tis_in_read_mode() == " << this->is_in_read_mode() + << "\n\tthis: " << this + ); + + // call the real function + return bit_stream_base::multi_read(data,num_to_read); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BIT_STREAM_MULTi_C_ + diff --git a/dlib/bits/c++config.h b/dlib/bits/c++config.h new file mode 100644 index 0000000000000000000000000000000000000000..6139ba8238c1db00fe4f10178abcb881715a2ef9 --- /dev/null +++ b/dlib/bits/c++config.h @@ -0,0 +1 @@ +#include "../dlib_include_path_tutorial.txt" diff --git a/dlib/bound_function_pointer.h b/dlib/bound_function_pointer.h new file mode 100644 index 0000000000000000000000000000000000000000..a482919c6e7db5aebbafa0d4846d6c55218a1810 --- /dev/null +++ b/dlib/bound_function_pointer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOUND_FUNCTION_POINTEr_ +#define DLIB_BOUND_FUNCTION_POINTEr_ + +#include "bound_function_pointer/bound_function_pointer_kernel_1.h" + +#endif // DLIB_BOUND_FUNCTION_POINTEr_ + + diff --git a/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h b/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..2984fbc93cf881682940dbd02ccd24e478000fbf --- /dev/null +++ b/dlib/bound_function_pointer/bound_function_pointer_kernel_1.h @@ -0,0 +1,774 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ +#define DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ + +#include "../algs.h" +#include "../member_function_pointer.h" +#include "bound_function_pointer_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace bfp1_helpers + { + template struct strip { typedef T type; }; + template struct strip { typedef T type; }; + + // ------------------------------------------------------------------------------------ + + class bound_function_helper_base_base + { + public: + virtual ~bound_function_helper_base_base(){} + virtual void call() const = 0; + virtual bool is_set() const = 0; + virtual void clone(void* ptr) const = 0; + }; + + // ------------------------------------------------------------------------------------ + + template + class bound_function_helper_base : public bound_function_helper_base_base + { + public: + bound_function_helper_base():arg1(0), arg2(0), arg3(0), arg4(0) {} + + typename strip::type* arg1; + typename strip::type* arg2; + typename strip::type* arg3; + typename strip::type* arg4; + + + member_function_pointer mfp; + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + else if (fp) fp(*this->arg1, *this->arg2, *this->arg3, *this->arg4); + } + + void (*fp)(T1, T2, T3, T4); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(); + } + + typename strip::type* fp; + }; + + template <> + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(); + else if (fp) fp(); + } + + void (*fp)(); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1); + else if (fp) fp(*this->arg1); + } + + void (*fp)(T1); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2); + else if (fp) fp(*this->arg1, *this->arg2); + } + + void (*fp)(T1, T2); + }; + + // ---------------- + + template + class bound_function_helper : public bound_function_helper_base + { + public: + void call() const + { + (*fp)(*this->arg1, *this->arg2, *this->arg3); + } + + typename strip::type* fp; + }; + + template + class bound_function_helper : public bound_function_helper_base + { + public: + + void call() const + { + if (this->mfp) this->mfp(*this->arg1, *this->arg2, *this->arg3); + else if (fp) fp(*this->arg1, *this->arg2, *this->arg3); + } + + void (*fp)(T1, T2, T3); + }; + + // ------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------ + + template + class bound_function_helper_T : public T + { + public: + bound_function_helper_T(){ this->fp = 0;} + + bool is_set() const + { + return this->fp != 0 || this->mfp.is_set(); + } + + template + void safe_clone(stack_based_memory_block& buf) + { + // This is here just to validate the assumption that our block of memory we have made + // in bf_memory is the right size to store the data for this object. If you + // get a compiler error on this line then email me :) + COMPILE_TIME_ASSERT(sizeof(bound_function_helper_T) <= mem_size); + clone(buf.get()); + } + + void clone (void* ptr) const + { + bound_function_helper_T* p = new(ptr) bound_function_helper_T(); + p->arg1 = this->arg1; + p->arg2 = this->arg2; + p->arg3 = this->arg3; + p->arg4 = this->arg4; + p->fp = this->fp; + p->mfp = this->mfp; + } + }; + + } + +// ---------------------------------------------------------------------------------------- + + class bound_function_pointer + { + typedef bfp1_helpers::bound_function_helper_T > bf_null_type; + + public: + + // These typedefs are here for backwards compatibility with previous versions of + // dlib. + typedef bound_function_pointer kernel_1a; + typedef bound_function_pointer kernel_1a_c; + + + bound_function_pointer ( + ) { bf_null_type().safe_clone(bf_memory); } + + bound_function_pointer ( + const bound_function_pointer& item + ) { item.bf()->clone(bf_memory.get()); } + + ~bound_function_pointer() + { destroy_bf_memory(); } + + bound_function_pointer& operator= ( + const bound_function_pointer& item + ) { bound_function_pointer(item).swap(*this); return *this; } + + void clear ( + ) { bound_function_pointer().swap(*this); } + + bool is_set ( + ) const + { + return bf()->is_set(); + } + + void swap ( + bound_function_pointer& item + ) + { + // make a temp copy of item + bound_function_pointer temp(item); + + // destory the stuff in item + item.destroy_bf_memory(); + // copy *this into item + bf()->clone(item.bf_memory.get()); + + // destory the stuff in this + destroy_bf_memory(); + // copy temp into *this + temp.bf()->clone(bf_memory.get()); + } + + void operator() ( + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_set() == true , + "\tvoid bound_function_pointer::operator()" + << "\n\tYou must call set() before you can use this function" + << "\n\tthis: " << this + ); + + bf()->call(); + } + + private: + struct dummy{ void nonnull() {}}; + typedef void (dummy::*safe_bool)(); + + public: + operator safe_bool () const { return is_set() ? &dummy::nonnull : 0; } + bool operator!() const { return !is_set(); } + + // ------------------------------------------- + // set function object overloads + // ------------------------------------------- + + template + void set ( + F& function_object + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + COMPILE_TIME_ASSERT(std::is_function::value == false); + COMPILE_TIME_ASSERT(std::is_pointer::value == false); + + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.fp = &function_object; + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + // set mfp overloads + // ------------------------------------------- + + template + void set ( + T& object, + void (T::*funct)() + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)()const + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + + template + void set ( + T& object, + void (T::*funct)(T1), + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1)const, + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2), + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2)const, + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2, T3), + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2, T3)const, + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ---------------- + + template + void set ( + T& object, + void (T::*funct)(T1, T2, T3, T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + template + void set ( + const T& object, + void (T::*funct)(T1, T2, T3, T4)const, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.mfp.set(object,funct); + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + // set fp overloads + // ------------------------------------------- + + void set ( + void (*funct)() + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1), + A1& arg1 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2), + A1& arg1, + A2& arg2 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2, T3), + A1& arg1, + A2& arg2, + A3& arg3 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + template + void set ( + void (*funct)(T1, T2, T3, T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ) + { + using namespace bfp1_helpers; + destroy_bf_memory(); + typedef bound_function_helper_T > bf_helper_type; + + bf_helper_type temp; + temp.arg1 = &arg1; + temp.arg2 = &arg2; + temp.arg3 = &arg3; + temp.arg4 = &arg4; + temp.fp = funct; + + temp.safe_clone(bf_memory); + } + + // ------------------------------------------- + + private: + + stack_based_memory_block bf_memory; + + void destroy_bf_memory ( + ) + { + // Honestly, this probably doesn't even do anything but I'm putting + // it here just for good measure. + bf()->~bound_function_helper_base_base(); + } + + bfp1_helpers::bound_function_helper_base_base* bf () + { return static_cast(bf_memory.get()); } + + const bfp1_helpers::bound_function_helper_base_base* bf () const + { return static_cast(bf_memory.get()); } + + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + bound_function_pointer& a, + bound_function_pointer& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_1_ + diff --git a/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h b/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..b5356d6e0015601d2ee5791771221ccb64442a21 --- /dev/null +++ b/dlib/bound_function_pointer/bound_function_pointer_kernel_abstract.h @@ -0,0 +1,456 @@ +// Copyright (C) 2008 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ +#ifdef DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bound_function_pointer + { + /*! + INITIAL VALUE + is_set() == false + + WHAT THIS OBJECT REPRESENTS + This object represents a function with all its arguments bound to + specific objects. For example: + + void test(int& var) { var = var+1; } + + bound_function_pointer funct; + + int a = 4; + funct.set(test,a); // bind the variable a to the first argument of the test() function + + // at this point a == 4 + funct(); + // after funct() is called a == 5 + !*/ + + public: + + bound_function_pointer ( + ); + /*! + ensures + - #*this is properly initialized + !*/ + + bound_function_pointer( + const bound_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + ~bound_function_pointer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + bound_function_pointer& operator=( + const bound_function_pointer& item + ); + /*! + ensures + - *this == item + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + !*/ + + bool is_set ( + ) const; + /*! + ensures + - if (this->set() has been called) then + - returns true + - else + - returns false + !*/ + + operator some_undefined_pointer_type ( + ) const; + /*! + ensures + - if (is_set()) then + - returns a non 0 value + - else + - returns a 0 value + !*/ + + bool operator! ( + ) const; + /*! + ensures + - returns !is_set() + !*/ + + void operator () ( + ) const; + /*! + requires + - is_set() == true + ensures + - calls the bound function on the object(s) specified by the last + call to this->set() + throws + - any exception thrown by the function specified by + the previous call to this->set(). + If any of these exceptions are thrown then the call to this + function will have no effect on *this. + !*/ + + void swap ( + bound_function_pointer& item + ); + /*! + ensures + - swaps *this and item + !*/ + + // ---------------------- + + template + void set ( + F& function_object + ); + /*! + requires + - function_object() is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object() + (This seems pointless but it is a useful base case) + !*/ + + template < typename T> + void set ( + T& object, + void (T::*funct)() + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)() + !*/ + + template < typename T> + void set ( + const T& object, + void (T::*funct)()const + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)() + !*/ + + void set ( + void (*funct)() + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct() + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1 + ); + /*! + requires + - function_object(arg1) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1) + !*/ + + template < typename T, typename T1, typename A1 > + void set ( + T& object, + void (T::*funct)(T1), + A1& arg1 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1) + !*/ + + template < typename T, typename T1, typename A1 > + void set ( + const T& object, + void (T::*funct)(T1)const, + A1& arg1 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1) + !*/ + + template + void set ( + void (*funct)(T1), + A1& arg1 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1) + !*/ + + // ---------------------- + template + void set ( + F& function_object, + A1& arg1, + A2& arg2 + ); + /*! + requires + - function_object(arg1,arg2) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2> + void set ( + T& object, + void (T::*funct)(T1,T2), + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2> + void set ( + const T& object, + void (T::*funct)(T1,T2)const, + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2) + !*/ + + template + void set ( + void (*funct)(T1,T2), + A1& arg1, + A2& arg2 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2) + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - function_object(arg1,arg2,arg3) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2,arg3) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3> + void set ( + T& object, + void (T::*funct)(T1,T2,T3), + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3> + void set ( + const T& object, + void (T::*funct)(T1,T2,T3)const, + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3) + !*/ + + template + void set ( + void (*funct)(T1,T2,T3), + A1& arg1, + A2& arg2, + A3& arg3 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2,arg3) + !*/ + + // ---------------------- + + template + void set ( + F& function_object, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - function_object(arg1,arg2,arg3,arg4) is a valid expression + ensures + - #is_set() == true + - calls to this->operator() will call function_object(arg1,arg2,arg3,arg4) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3, + typename T4, typename A4> + void set ( + T& object, + void (T::*funct)(T1,T2,T3,T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid member function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4) + !*/ + + template < typename T, typename T1, typename A1, + typename T2, typename A2, + typename T3, typename A3, + typename T4, typename A4> + void set ( + const T& object, + void (T::*funct)(T1,T2,T3,T4)const, + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid bound function pointer for class T + ensures + - #is_set() == true + - calls to this->operator() will call (object.*funct)(arg1,arg2,arg3,arg4) + !*/ + + template + void set ( + void (*funct)(T1,T2,T3,T4), + A1& arg1, + A2& arg2, + A3& arg3, + A4& arg4 + ); + /*! + requires + - funct == a valid function pointer + ensures + - #is_set() == true + - calls to this->operator() will call funct(arg1,arg2,arg3,arg4) + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + + inline void swap ( + bound_function_pointer& a, + bound_function_pointer& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOUND_FUNCTION_POINTER_KERNEl_ABSTRACT_ + diff --git a/dlib/bridge.h b/dlib/bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..4b633c4053c11fca7f223b6a153c9fb1dc8daa6d --- /dev/null +++ b/dlib/bridge.h @@ -0,0 +1,17 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. + +#ifdef DLIB_ALL_SOURCE_END +#include "dlib_basic_cpp_build_tutorial.txt" +#endif + + +#ifndef DLIB_BRIdGE_ +#define DLIB_BRIdGE_ + + +#include "bridge/bridge.h" + +#endif // DLIB_BRIdGE_ + + diff --git a/dlib/bridge/bridge.h b/dlib/bridge/bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..93e23995aff6172041134866b14d038993da91ec --- /dev/null +++ b/dlib/bridge/bridge.h @@ -0,0 +1,669 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BRIDGe_Hh_ +#define DLIB_BRIDGe_Hh_ + +#include +#include +#include + +#include "bridge_abstract.h" +#include "../pipe.h" +#include "../threads.h" +#include "../serialize.h" +#include "../sockets.h" +#include "../sockstreambuf.h" +#include "../logger.h" +#include "../algs.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct connect_to_ip_and_port + { + connect_to_ip_and_port ( + const std::string& ip_, + unsigned short port_ + ): ip(ip_), port(port_) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ip_address(ip) && port != 0, + "\t connect_to_ip_and_port()" + << "\n\t Invalid inputs were given to this function" + << "\n\t ip: " << ip + << "\n\t port: " << port + << "\n\t this: " << this + ); + } + + private: + friend class bridge; + const std::string ip; + const unsigned short port; + }; + + inline connect_to_ip_and_port connect_to ( + const network_address& addr + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(addr.port != 0, + "\t connect_to_ip_and_port()" + << "\n\t The TCP port to connect to can't be 0." + << "\n\t addr.port: " << addr.port + ); + + if (is_ip_address(addr.host_address)) + { + return connect_to_ip_and_port(addr.host_address, addr.port); + } + else + { + std::string ip; + if(hostname_to_ip(addr.host_address,ip)) + throw socket_error(ERESOLVE,"unable to resolve '" + addr.host_address + "' in connect_to()"); + + return connect_to_ip_and_port(ip, addr.port); + } + } + + struct listen_on_port + { + listen_on_port( + unsigned short port_ + ) : port(port_) + { + // make sure requires clause is not broken + DLIB_ASSERT( port != 0, + "\t listen_on_port()" + << "\n\t Invalid inputs were given to this function" + << "\n\t port: " << port + << "\n\t this: " << this + ); + } + + private: + friend class bridge; + const unsigned short port; + }; + + template + struct bridge_transmit_decoration + { + bridge_transmit_decoration ( + pipe_type& p_ + ) : p(p_) {} + + private: + friend class bridge; + pipe_type& p; + }; + + template + bridge_transmit_decoration transmit ( pipe_type& p) { return bridge_transmit_decoration(p); } + + template + struct bridge_receive_decoration + { + bridge_receive_decoration ( + pipe_type& p_ + ) : p(p_) {} + + private: + friend class bridge; + pipe_type& p; + }; + + template + bridge_receive_decoration receive ( pipe_type& p) { return bridge_receive_decoration(p); } + +// ---------------------------------------------------------------------------------------- + + struct bridge_status + { + bridge_status() : is_connected(false), foreign_port(0){} + + bool is_connected; + unsigned short foreign_port; + std::string foreign_ip; + }; + + inline void serialize ( const bridge_status& , std::ostream& ) + { + throw serialization_error("It is illegal to serialize bridge_status objects."); + } + + inline void deserialize ( bridge_status& , std::istream& ) + { + throw serialization_error("It is illegal to serialize bridge_status objects."); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl_brns + { + class impl_bridge_base + { + public: + + virtual ~impl_bridge_base() {} + + virtual bridge_status get_bridge_status ( + ) const = 0; + }; + + template < + typename transmit_pipe_type, + typename receive_pipe_type + > + class impl_bridge : public impl_bridge_base, private noncopyable, private multithreaded_object + { + /*! + CONVENTION + - if (list) then + - this object is supposed to be listening on the list object for incoming + connections when not connected. + - else + - this object is supposed to be attempting to connect to ip:port when + not connected. + + - get_bridge_status() == current_bs + !*/ + public: + + impl_bridge ( + unsigned short listen_port, + transmit_pipe_type* transmit_pipe_, + receive_pipe_type* receive_pipe_ + ) : + s(m), + receive_thread_active(false), + transmit_thread_active(false), + port(0), + transmit_pipe(transmit_pipe_), + receive_pipe(receive_pipe_), + dlog("dlib.bridge"), + keepalive_code(0), + message_code(1) + { + int status = create_listener(list, listen_port); + if (status == PORTINUSE) + { + std::ostringstream sout; + sout << "Error, the port " << listen_port << " is already in use."; + throw socket_error(EPORT_IN_USE, sout.str()); + } + else if (status == OTHER_ERROR) + { + throw socket_error("Unable to create listening socket for an unknown reason."); + } + + register_thread(*this, &impl_bridge::transmit_thread); + register_thread(*this, &impl_bridge::receive_thread); + register_thread(*this, &impl_bridge::connect_thread); + + start(); + } + + impl_bridge ( + const std::string ip_, + unsigned short port_, + transmit_pipe_type* transmit_pipe_, + receive_pipe_type* receive_pipe_ + ) : + s(m), + receive_thread_active(false), + transmit_thread_active(false), + port(port_), + ip(ip_), + transmit_pipe(transmit_pipe_), + receive_pipe(receive_pipe_), + dlog("dlib.bridge"), + keepalive_code(0), + message_code(1) + { + register_thread(*this, &impl_bridge::transmit_thread); + register_thread(*this, &impl_bridge::receive_thread); + register_thread(*this, &impl_bridge::connect_thread); + + start(); + } + + ~impl_bridge() + { + // tell the threads to terminate + stop(); + + // save current pipe enabled status so we can restore it to however + // it was before this destructor ran. + bool transmit_enabled = true; + bool receive_enabled = true; + + // make any calls blocked on a pipe return immediately. + if (transmit_pipe) + { + transmit_enabled = transmit_pipe->is_dequeue_enabled(); + transmit_pipe->disable_dequeue(); + } + if (receive_pipe) + { + receive_enabled = receive_pipe->is_enqueue_enabled(); + receive_pipe->disable_enqueue(); + } + + { + auto_mutex lock(m); + s.broadcast(); + // Shutdown the connection if we have one. This will cause + // all blocked I/O calls to return an error. + if (con) + con->shutdown(); + } + + // wait for all the threads to terminate. + wait(); + + if (transmit_pipe && transmit_enabled) + transmit_pipe->enable_dequeue(); + if (receive_pipe && receive_enabled) + receive_pipe->enable_enqueue(); + } + + bridge_status get_bridge_status ( + ) const + { + auto_mutex lock(current_bs_mutex); + return current_bs; + } + + private: + + + template + std::enable_if_t::value> enqueue_bridge_status ( + pipe_type* p, + const bridge_status& status + ) + { + if (p) + { + typename pipe_type::type temp(status); + p->enqueue(temp); + } + } + + template + std::enable_if_t::value> enqueue_bridge_status ( + pipe_type* , + const bridge_status& + ) + { + } + + void connect_thread ( + ) + { + while (!should_stop()) + { + auto_mutex lock(m); + int status = OTHER_ERROR; + if (list) + { + do + { + status = list->accept(con, 1000); + } while (status == TIMEOUT && !should_stop()); + } + else + { + status = create_connection(con, port, ip); + } + + if (should_stop()) + break; + + if (status != 0) + { + // The last connection attempt failed. So pause for a little bit before making another attempt. + s.wait_or_timeout(2000); + continue; + } + + dlog << LINFO << "Established new connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; + + bridge_status temp_bs; + { auto_mutex lock(current_bs_mutex); + current_bs.is_connected = true; + current_bs.foreign_port = con->get_foreign_port(); + current_bs.foreign_ip = con->get_foreign_ip(); + temp_bs = current_bs; + } + enqueue_bridge_status(receive_pipe, temp_bs); + + + receive_thread_active = true; + transmit_thread_active = true; + + s.broadcast(); + + // Wait for the transmit and receive threads to end before we continue. + // This way we don't invalidate the con pointer while it is in use. + while (receive_thread_active || transmit_thread_active) + s.wait(); + + + dlog << LINFO << "Closed connection to " << con->get_foreign_ip() << ":" << con->get_foreign_port() << "."; + { auto_mutex lock(current_bs_mutex); + current_bs.is_connected = false; + current_bs.foreign_port = con->get_foreign_port(); + current_bs.foreign_ip = con->get_foreign_ip(); + temp_bs = current_bs; + } + enqueue_bridge_status(receive_pipe, temp_bs); + } + + } + + + void receive_thread ( + ) + { + while (true) + { + // wait until we have a connection + { auto_mutex lock(m); + while (!receive_thread_active && !should_stop()) + { + s.wait(); + } + + if (should_stop()) + break; + } + + + + try + { + if (receive_pipe) + { + sockstreambuf buf(con); + std::istream in(&buf); + typename receive_pipe_type::type item; + // This isn't necessary but doing it avoids a warning about + // item being uninitialized sometimes. + assign_zero_if_built_in_scalar_type(item); + + while (in.peek() != EOF) + { + unsigned char code; + in.read((char*)&code, sizeof(code)); + if (code == message_code) + { + deserialize(item, in); + receive_pipe->enqueue(item); + } + } + } + else + { + // Since we don't have a receive pipe to put messages into we will + // just read the bytes from the connection and ignore them. + char buf[1000]; + while (con->read(buf, sizeof(buf)) > 0) ; + } + } + catch (std::bad_alloc& ) + { + dlog << LERROR << "std::bad_alloc thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port(); + } + catch (dlib::serialization_error& e) + { + dlog << LERROR << "dlib::serialization_error thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + catch (std::exception& e) + { + dlog << LERROR << "std::exception thrown while deserializing message from " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + + + + + con->shutdown(); + auto_mutex lock(m); + receive_thread_active = false; + s.broadcast(); + } + + auto_mutex lock(m); + receive_thread_active = false; + s.broadcast(); + } + + void transmit_thread ( + ) + { + while (true) + { + // wait until we have a connection + { auto_mutex lock(m); + while (!transmit_thread_active && !should_stop()) + { + s.wait(); + } + + if (should_stop()) + break; + } + + + + try + { + sockstreambuf buf(con); + std::ostream out(&buf); + typename transmit_pipe_type::type item; + // This isn't necessary but doing it avoids a warning about + // item being uninitialized sometimes. + assign_zero_if_built_in_scalar_type(item); + + + while (out) + { + bool dequeue_timed_out = false; + if (transmit_pipe ) + { + if (transmit_pipe->dequeue_or_timeout(item,1000)) + { + out.write((char*)&message_code, sizeof(message_code)); + serialize(item, out); + if (transmit_pipe->size() == 0) + out.flush(); + + continue; + } + + dequeue_timed_out = (transmit_pipe->is_enabled() && transmit_pipe->is_dequeue_enabled()); + } + + // Pause for about a second. Note that we use a wait_or_timeout() call rather + // than sleep() here because we want to wake up immediately if this object is + // being destructed rather than hang for a second. + if (!dequeue_timed_out) + { + auto_mutex lock(m); + if (should_stop()) + break; + + s.wait_or_timeout(1000); + } + // Just send the keepalive byte periodically so we can + // tell if the connection is alive. + out.write((char*)&keepalive_code, sizeof(keepalive_code)); + out.flush(); + } + } + catch (std::bad_alloc& ) + { + dlog << LERROR << "std::bad_alloc thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port(); + } + catch (dlib::serialization_error& e) + { + dlog << LERROR << "dlib::serialization_error thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + catch (std::exception& e) + { + dlog << LERROR << "std::exception thrown while serializing message to " + << con->get_foreign_ip() << ":" << con->get_foreign_port() + << ".\nThe exception error message is: \n" << e.what(); + } + + + + + con->shutdown(); + auto_mutex lock(m); + transmit_thread_active = false; + s.broadcast(); + } + + auto_mutex lock(m); + transmit_thread_active = false; + s.broadcast(); + } + + mutex m; + signaler s; + bool receive_thread_active; + bool transmit_thread_active; + std::unique_ptr con; + std::unique_ptr list; + const unsigned short port; + const std::string ip; + transmit_pipe_type* const transmit_pipe; + receive_pipe_type* const receive_pipe; + logger dlog; + const unsigned char keepalive_code; + const unsigned char message_code; + + mutex current_bs_mutex; + bridge_status current_bs; + }; + } + + +// ---------------------------------------------------------------------------------------- + + class bridge : noncopyable + { + public: + + bridge () {} + + template < typename T, typename U, typename V > + bridge ( + T network_parameters, + U pipe1, + V pipe2 + ) { reconfigure(network_parameters,pipe1,pipe2); } + + template < typename T, typename U> + bridge ( + T network_parameters, + U pipe + ) { reconfigure(network_parameters,pipe); } + + + void clear ( + ) + { + pimpl.reset(); + } + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, &transmit_pipe.p, 0)); } + + template < typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.port, 0, &receive_pipe.p)); } + + + + + template < typename T, typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename T, typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, &receive_pipe.p)); } + + template < typename R > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, 0, &receive_pipe.p)); } + + template < typename T > + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe + ) { pimpl.reset(); pimpl.reset(new impl_brns::impl_bridge(network_parameters.ip, network_parameters.port, &transmit_pipe.p, 0)); } + + + bridge_status get_bridge_status ( + ) const + { + if (pimpl) + return pimpl->get_bridge_status(); + else + return bridge_status(); + } + + private: + + std::unique_ptr pimpl; + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BRIDGe_Hh_ + diff --git a/dlib/bridge/bridge_abstract.h b/dlib/bridge/bridge_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..76ed21153af9119a82d265b40cc4926046362c94 --- /dev/null +++ b/dlib/bridge/bridge_abstract.h @@ -0,0 +1,347 @@ +// Copyright (C) 2011 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BRIDGe_ABSTRACT_ +#ifdef DLIB_BRIDGe_ABSTRACT_ + +#include +#include "../pipe/pipe_kernel_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + struct connect_to_ip_and_port + { + connect_to_ip_and_port ( + const std::string& ip, + unsigned short port + ); + /*! + requires + - is_ip_address(ip) == true + - port != 0 + ensures + - this object will represent a request to make a TCP connection + to the given IP address and port number. + !*/ + }; + + connect_to_ip_and_port connect_to ( + const network_address& addr + ); + /*! + requires + - addr.port != 0 + ensures + - converts the given network_address object into a connect_to_ip_and_port + object. + !*/ + + struct listen_on_port + { + listen_on_port( + unsigned short port + ); + /*! + requires + - port != 0 + ensures + - this object will represent a request to listen on the given + port number for incoming TCP connections. + !*/ + }; + + template < + typename pipe_type + > + bridge_transmit_decoration transmit ( + pipe_type& p + ); + /*! + requires + - pipe_type is some kind of dlib::pipe object + - the objects in the pipe must be serializable + ensures + - Adds a type decoration to the given pipe, marking it as a transmit pipe, and + then returns it. + !*/ + + template < + typename pipe_type + > + bridge_receive_decoration receive ( + pipe_type& p + ); + /*! + requires + - pipe_type is some kind of dlib::pipe object + - the objects in the pipe must be serializable + ensures + - Adds a type decoration to the given pipe, marking it as a receive pipe, and + then returns it. + !*/ + +// ---------------------------------------------------------------------------------------- + + struct bridge_status + { + /*! + WHAT THIS OBJECT REPRESENTS + This simple struct represents the state of a bridge object. A + bridge is either connected or not. If it is connected then it + is connected to a foreign host with an IP address and port number + as indicated by this object. + !*/ + + bridge_status( + ); + /*! + ensures + - #is_connected == false + - #foreign_port == 0 + - #foreign_ip == "" + !*/ + + bool is_connected; + unsigned short foreign_port; + std::string foreign_ip; + }; + +// ---------------------------------------------------------------------------------------- + + class bridge : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for bridging a dlib::pipe object between + two network connected applications. + + + Note also that this object contains a dlib::logger object + which will log various events taking place inside a bridge. + If you want to see these log messages then enable the logger + named "dlib.bridge". + + + BRIDGE PROTOCOL DETAILS + The bridge object creates a single TCP connection between + two applications. Whenever it sends an object from a pipe + over a TCP connection it sends a byte with the value 1 followed + immediately by the serialized copy of the object from the pipe. + The serialization is performed by calling the global serialize() + function. + + Additionally, a bridge object will periodically send bytes with + a value of 0 to ensure the TCP connection remains alive. These + are just read and ignored. + !*/ + + public: + + bridge ( + ); + /*! + ensures + - this object is properly initialized + - #get_bridge_status().is_connected == false + !*/ + + template + bridge ( + T network_parameters, + U pipe1, + V pipe2 + ); + /*! + requires + - T is of type connect_to_ip_and_port or listen_on_port + - U and V are of type bridge_transmit_decoration or bridge_receive_decoration, + however, U and V must be of different types (i.e. one is a receive type and + another a transmit type). + ensures + - this object is properly initialized + - performs: reconfigure(network_parameters, pipe1, pipe2) + (i.e. using this constructor is identical to using the default constructor + and then calling reconfigure()) + !*/ + + template + bridge ( + T network_parameters, + U pipe + ); + /*! + requires + - T is of type connect_to_ip_and_port or listen_on_port + - U is of type bridge_transmit_decoration or bridge_receive_decoration. + ensures + - this object is properly initialized + - performs: reconfigure(network_parameters, pipe) + (i.e. using this constructor is identical to using the default constructor + and then calling reconfigure()) + !*/ + + ~bridge ( + ); + /*! + ensures + - blocks until all resources associated with this object have been destroyed. + !*/ + + void clear ( + ); + /*! + ensures + - returns this object to its default constructed state. That is, it will + be inactive, neither maintaining a connection nor attempting to acquire one. + - Any active connections or listening sockets will be closed. + !*/ + + bridge_status get_bridge_status ( + ) const; + /*! + ensures + - returns the current status of this bridge object. In particular, returns + an object BS such that: + - BS.is_connected == true if and only if the bridge has an active TCP + connection to another computer. + - if (BS.is_connected) then + - BS.foreign_ip == the IP address of the remote host we are connected to. + - BS.foreign_port == the port number on the remote host we are connected to. + - else if (the bridge has previously been connected to a remote host but hasn't been + reconfigured or cleared since) then + - BS.foreign_ip == the IP address of the remote host we were connected to. + - BS.foreign_port == the port number on the remote host we were connected to. + - else + - BS.foreign_ip == "" + - BS.foreign_port == 0 + !*/ + + + + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This object will begin listening on the port specified by network_parameters + for incoming TCP connections. Any previous bridge state is cleared out. + - Onces a connection is established we will: + - Stop accepting new connections. + - Begin dequeuing objects from the transmit pipe and serializing them over + the TCP connection. + - Begin deserializing objects from the TCP connection and enqueueing them + onto the receive pipe. + - if (the current TCP connection is lost) then + - This object goes back to listening for a new connection. + - if (the receive pipe can contain bridge_status objects) then + - Whenever the bridge's status changes the updated bridge_status will be + enqueued onto the receive pipe unless the change was a TCP disconnect + resulting from a user calling reconfigure(), clear(), or destructing this + bridge. The status contents are defined by get_bridge_status(). + throws + - socket_error + This exception is thrown if we are unable to open the listening socket. + !*/ + template < typename T, typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) + !*/ + template < typename T > + void reconfigure ( + listen_on_port network_parameters, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - This function is identical to the above two reconfigure() functions + except that there is no receive pipe. + !*/ + template < typename R > + void reconfigure ( + listen_on_port network_parameters, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This function is identical to the above three reconfigure() functions + except that there is no transmit pipe. + !*/ + + + + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This object will begin making TCP connection attempts to the IP address and port + specified by network_parameters. Any previous bridge state is cleared out. + - Onces a connection is established we will: + - Stop attempting new connections. + - Begin dequeuing objects from the transmit pipe and serializing them over + the TCP connection. + - Begin deserializing objects from the TCP connection and enqueueing them + onto the receive pipe. + - if (the current TCP connection is lost) then + - This object goes back to attempting to make a TCP connection with the + IP address and port specified by network_parameters. + - if (the receive pipe can contain bridge_status objects) then + - Whenever the bridge's status changes the updated bridge_status will be + enqueued onto the receive pipe unless the change was a TCP disconnect + resulting from a user calling reconfigure(), clear(), or destructing this + bridge. The status contents are defined by get_bridge_status(). + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - performs reconfigure(network_parameters, transmit_pipe, receive_pipe) + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_transmit_decoration transmit_pipe + ); + /*! + ensures + - This function is identical to the above two reconfigure() functions + except that there is no receive pipe. + !*/ + template + void reconfigure ( + connect_to_ip_and_port network_parameters, + bridge_receive_decoration receive_pipe + ); + /*! + ensures + - This function is identical to the above three reconfigure() functions + except that there is no transmit pipe. + !*/ + + }; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BRIDGe_ABSTRACT_ + + diff --git a/dlib/bsp.h b/dlib/bsp.h new file mode 100644 index 0000000000000000000000000000000000000000..899b6a40517ae4b79318f231afd0a3eb70f5d1aa --- /dev/null +++ b/dlib/bsp.h @@ -0,0 +1,12 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BSPh_ +#define DLIB_BSPh_ + + +#include "bsp/bsp.h" + +#endif // DLIB_BSPh_ + + + diff --git a/dlib/bsp/bsp.cpp b/dlib/bsp/bsp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..32e23519e7ed0193932a8ce6d2c5d6ed1cfa1372 --- /dev/null +++ b/dlib/bsp/bsp.cpp @@ -0,0 +1,496 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BSP_CPph_ +#define DLIB_BSP_CPph_ + +#include "bsp.h" +#include +#include + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +namespace dlib +{ + + namespace impl1 + { + + void connect_all ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + std::unique_ptr con(new bsp_con(hosts[i])); + dlib::serialize(node_id, con->stream); // tell the other end our node_id + unsigned long id = i+1; + cons.add(id, con); + } + } + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id, + std::string& error_string + ) + { + cons.clear(); + for (unsigned long i = 0; i < hosts.size(); ++i) + { + try + { + std::unique_ptr con(new bsp_con(hosts[i].addr)); + dlib::serialize(node_id, con->stream); // tell the other end our node_id + con->stream.flush(); + unsigned long id = hosts[i].node_id; + cons.add(id, con); + } + catch (std::exception&) + { + std::ostringstream sout; + sout << "Could not connect to " << hosts[i].addr; + error_string = sout.str(); + break; + } + } + } + + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector& hosts + ) + { + // tell everyone their node ids + cons.reset(); + while (cons.move_next()) + { + dlib::serialize(cons.element().key(), cons.element().value()->stream); + } + + // now tell them who to connect to + std::vector targets; + for (unsigned long i = 0; i < hosts.size(); ++i) + { + hostinfo info(hosts[i], i+1); + + dlib::serialize(targets, cons[info.node_id]->stream); + targets.push_back(info); + + // let the other host know how many incoming connections to expect + const unsigned long num = hosts.size()-targets.size(); + dlib::serialize(num, cons[info.node_id]->stream); + cons[info.node_id]->stream.flush(); + } + } + + // ------------------------------------------------------------------------------------ + + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + namespace impl2 + { + // These control bytes are sent before each message between nodes. Note that many + // of these are only sent between the control node (node 0) and the other nodes. + // This is because the controller node is responsible for handling the + // synchronization that needs to happen when all nodes block on calls to + // receive_data() + // at the same time. + + // denotes a normal content message. + const static char MESSAGE_HEADER = 0; + + // sent to the controller node when someone receives a message via receive_data(). + const static char GOT_MESSAGE = 1; + + // sent to the controller node when someone sends a message via send(). + const static char SENT_MESSAGE = 2; + + // sent to the controller node when someone enters a call to receive_data() + const static char IN_WAITING_STATE = 3; + + // broadcast when a node terminates itself. + const static char NODE_TERMINATE = 5; + + // broadcast by the controller node when it determines that all nodes are blocked + // on calls to receive_data() and there aren't any messages in flight. This is also + // what makes us go to the next epoch. + const static char SEE_ALL_IN_WAITING_STATE = 6; + + // This isn't ever transmitted between nodes. It is used internally to indicate + // that an error occurred. + const static char READ_ERROR = 7; + + // ------------------------------------------------------------------------------------ + + void read_thread ( + impl1::bsp_con* con, + unsigned long node_id, + unsigned long sender_id, + impl1::thread_safe_message_queue& msg_buffer + ) + { + try + { + while(true) + { + impl1::msg_data msg; + deserialize(msg.msg_type, con->stream); + msg.sender_id = sender_id; + + if (msg.msg_type == MESSAGE_HEADER) + { + msg.data.reset(new std::vector); + deserialize(msg.epoch, con->stream); + deserialize(*msg.data, con->stream); + } + + msg_buffer.push_and_consume(msg); + + if (msg.msg_type == NODE_TERMINATE) + break; + } + } + catch (std::exception& e) + { + impl1::msg_data msg; + msg.data.reset(new std::vector); + vectorstream sout(*msg.data); + sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; + sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; + sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; + sout << " Receiving processing node id: " << node_id << std::endl; + sout << " Error message in the exception: " << e.what() << std::endl; + + msg.sender_id = sender_id; + msg.msg_type = READ_ERROR; + + msg_buffer.push_and_consume(msg); + } + catch (...) + { + impl1::msg_data msg; + msg.data.reset(new std::vector); + vectorstream sout(*msg.data); + sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; + sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; + sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; + sout << " Receiving processing node id: " << node_id << std::endl; + + msg.sender_id = sender_id; + msg.msg_type = READ_ERROR; + + msg_buffer.push_and_consume(msg); + } + } + + // ------------------------------------------------------------------------------------ + + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION OF bsp_context OBJECT MEMBERS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + close_all_connections_gracefully( + ) + { + if (node_id() != 0) + { + _cons.reset(); + while (_cons.move_next()) + { + // tell the other end that we are intentionally dropping the connection + serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); + _cons.element().value()->stream.flush(); + } + } + + impl1::msg_data msg; + // now wait for all the other nodes to terminate + while (num_terminated_nodes < _cons.size() ) + { + if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0) + { + num_waiting_nodes = 0; + broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); + ++current_epoch; + } + + if (!msg_buffer.pop(msg)) + throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); + + if (msg.msg_type == impl2::NODE_TERMINATE) + { + ++num_terminated_nodes; + _cons[msg.sender_id]->terminated = true; + } + else if (msg.msg_type == impl2::READ_ERROR) + { + throw dlib::socket_error(msg.data_to_string()); + } + else if (msg.msg_type == impl2::MESSAGE_HEADER) + { + throw dlib::socket_error("A BSP node received a message after it has terminated."); + } + else if (msg.msg_type == impl2::GOT_MESSAGE) + { + --num_waiting_nodes; + --outstanding_messages; + } + else if (msg.msg_type == impl2::SENT_MESSAGE) + { + ++outstanding_messages; + } + else if (msg.msg_type == impl2::IN_WAITING_STATE) + { + ++num_waiting_nodes; + } + } + + if (node_id() == 0) + { + _cons.reset(); + while (_cons.move_next()) + { + // tell the other end that we are intentionally dropping the connection + serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); + _cons.element().value()->stream.flush(); + } + + if (outstanding_messages != 0) + { + std::ostringstream sout; + sout << "A BSP job was allowed to terminate before all sent messages have been received.\n"; + sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n"; + sout << "have a corresponding call to receive()."; + throw dlib::socket_error(sout.str()); + } + } + } + +// ---------------------------------------------------------------------------------------- + + bsp_context:: + ~bsp_context() + { + _cons.reset(); + while (_cons.move_next()) + { + _cons.element().value()->con->shutdown(); + } + + msg_buffer.disable(); + + // this will wait for all the threads to terminate + threads.clear(); + } + +// ---------------------------------------------------------------------------------------- + + bsp_context:: + bsp_context( + unsigned long node_id_, + impl1::map_id_to_con& cons_ + ) : + outstanding_messages(0), + num_waiting_nodes(0), + num_terminated_nodes(0), + current_epoch(1), + _cons(cons_), + _node_id(node_id_) + { + // spawn a bunch of read threads, one for each connection + _cons.reset(); + while (_cons.move_next()) + { + std::unique_ptr ptr(new thread_function(&impl2::read_thread, + _cons.element().value().get(), + _node_id, + _cons.element().key(), + ref(msg_buffer))); + threads.push_back(ptr); + } + + } + +// ---------------------------------------------------------------------------------------- + + bool bsp_context:: + receive_data ( + std::shared_ptr >& item, + unsigned long& sending_node_id + ) + { + notify_control_node(impl2::IN_WAITING_STATE); + + while (true) + { + // If there aren't any nodes left to give us messages then return right now. + // We need to check the msg_buffer size to make sure there aren't any + // unprocessed message there. Recall that this can happen because status + // messages always jump to the front of the message buffer. So we might have + // learned about the node terminations before processing their messages for us. + if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0) + { + return false; + } + + // if all running nodes are currently blocking forever on receive_data() + if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size()) + { + num_waiting_nodes = 0; + broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); + + // Note that the reason we have this epoch counter is so we can tell if a + // sent message is from before or after one of these "all nodes waiting" + // synchronization events. If we didn't have the epoch count we would have + // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE + // message before others and then sends out a message to another node + // before that node got the SEE_ALL_IN_WAITING_STATE message. Then that + // node would think the normal message came before SEE_ALL_IN_WAITING_STATE + // which would be bad. + ++current_epoch; + return false; + } + + impl1::msg_data data; + if (!msg_buffer.pop(data, current_epoch)) + throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); + + + switch(data.msg_type) + { + case impl2::MESSAGE_HEADER: { + item = data.data; + sending_node_id = data.sender_id; + notify_control_node(impl2::GOT_MESSAGE); + return true; + } break; + + case impl2::IN_WAITING_STATE: { + ++num_waiting_nodes; + } break; + + case impl2::GOT_MESSAGE: { + --outstanding_messages; + --num_waiting_nodes; + } break; + + case impl2::SENT_MESSAGE: { + ++outstanding_messages; + } break; + + case impl2::NODE_TERMINATE: { + ++num_terminated_nodes; + _cons[data.sender_id]->terminated = true; + } break; + + case impl2::SEE_ALL_IN_WAITING_STATE: { + ++current_epoch; + return false; + } break; + + case impl2::READ_ERROR: { + throw dlib::socket_error(data.data_to_string()); + } break; + + default: { + throw dlib::socket_error("Unknown message received by dlib::bsp_context"); + } break; + } // end switch() + } // end while (true) + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + notify_control_node ( + char val + ) + { + if (node_id() == 0) + { + using namespace impl2; + switch(val) + { + case SENT_MESSAGE: { + ++outstanding_messages; + } break; + + case GOT_MESSAGE: { + --outstanding_messages; + } break; + + case IN_WAITING_STATE: { + // nothing to do in this case + } break; + + default: + DLIB_CASSERT(false,"This should never happen"); + } + } + else + { + serialize(val, _cons[0]->stream); + _cons[0]->stream.flush(); + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + broadcast_byte ( + char val + ) + { + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + // don't send to yourself or to terminated nodes + if (i == node_id() || _cons[i]->terminated) + continue; + + serialize(val, _cons[i]->stream); + _cons[i]->stream.flush(); + } + } + +// ---------------------------------------------------------------------------------------- + + void bsp_context:: + send_data( + const std::vector& item, + unsigned long target_node_id + ) + { + using namespace impl2; + if (_cons[target_node_id]->terminated) + throw socket_error("Attempt to send a message to a node that has terminated."); + + serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); + serialize(current_epoch, _cons[target_node_id]->stream); + serialize(item, _cons[target_node_id]->stream); + _cons[target_node_id]->stream.flush(); + + notify_control_node(SENT_MESSAGE); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BSP_CPph_ + diff --git a/dlib/bsp/bsp.h b/dlib/bsp/bsp.h new file mode 100644 index 0000000000000000000000000000000000000000..f0732c15380ca413850625887b3c8c21708261a4 --- /dev/null +++ b/dlib/bsp/bsp.h @@ -0,0 +1,1043 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BsP_Hh_ +#define DLIB_BsP_Hh_ + +#include "bsp_abstract.h" + +#include +#include +#include + +#include "../sockets.h" +#include "../array.h" +#include "../sockstreambuf.h" +#include "../string.h" +#include "../serialize.h" +#include "../map.h" +#include "../ref.h" +#include "../vectorstream.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl1 + { + inline void null_notify( + unsigned short + ) {} + + struct bsp_con + { + bsp_con( + const network_address& dest + ) : + con(connect(dest)), + buf(con), + stream(&buf), + terminated(false) + { + con->disable_nagle(); + } + + bsp_con( + std::unique_ptr& conptr + ) : + buf(conptr), + stream(&buf), + terminated(false) + { + // make sure we own the connection + conptr.swap(con); + + con->disable_nagle(); + } + + std::unique_ptr con; + sockstreambuf buf; + std::iostream stream; + bool terminated; + }; + + typedef dlib::map >::kernel_1a_c map_id_to_con; + + void connect_all ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id + ); + /*! + ensures + - creates connections to all the given hosts and stores them into cons + !*/ + + void send_out_connection_orders ( + map_id_to_con& cons, + const std::vector& hosts + ); + + // ------------------------------------------------------------------------------------ + + struct hostinfo + { + hostinfo() {} + hostinfo ( + const network_address& addr_, + unsigned long node_id_ + ) : + addr(addr_), + node_id(node_id_) + { + } + + network_address addr; + unsigned long node_id; + }; + + inline void serialize ( + const hostinfo& item, + std::ostream& out + ) + { + dlib::serialize(item.addr, out); + dlib::serialize(item.node_id, out); + } + + inline void deserialize ( + hostinfo& item, + std::istream& in + ) + { + dlib::deserialize(item.addr, in); + dlib::deserialize(item.node_id, in); + } + + // ------------------------------------------------------------------------------------ + + void connect_all_hostinfo ( + map_id_to_con& cons, + const std::vector& hosts, + unsigned long node_id, + std::string& error_string + ); + + // ------------------------------------------------------------------------------------ + + template < + typename port_notify_function_type + > + void listen_and_connect_all( + unsigned long& node_id, + map_id_to_con& cons, + unsigned short port, + port_notify_function_type port_notify_function + ) + { + cons.clear(); + std::unique_ptr list; + const int status = create_listener(list, port); + if (status == PORTINUSE) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) + + ". The port is already in use"); + } + else if (status != 0) + { + throw socket_error("Unable to create listening port " + cast_to_string(port) ); + } + + port_notify_function(list->get_listening_port()); + + std::unique_ptr con; + if (list->accept(con)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + std::unique_ptr temp(new bsp_con(con)); + + unsigned long remote_node_id; + dlib::deserialize(remote_node_id, temp->stream); + dlib::deserialize(node_id, temp->stream); + std::vector targets; + dlib::deserialize(targets, temp->stream); + unsigned long num_incoming_connections; + dlib::deserialize(num_incoming_connections, temp->stream); + + cons.add(remote_node_id,temp); + + // make a thread that will connect to all the targets + map_id_to_con cons2; + std::string error_string; + thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string)); + if (error_string.size() != 0) + throw socket_error(error_string); + + // accept any incoming connections + for (unsigned long i = 0; i < num_incoming_connections; ++i) + { + // If it takes more than 10 seconds for the other nodes to connect to us + // then something has gone horribly wrong and it almost certainly will + // never connect at all. So just give up if that happens. + const unsigned long timeout_milliseconds = 10000; + if (list->accept(con, timeout_milliseconds)) + { + throw socket_error("Error occurred while accepting new connection"); + } + + temp.reset(new bsp_con(con)); + + dlib::deserialize(remote_node_id, temp->stream); + cons.add(remote_node_id,temp); + } + + + // put all the connections created by the thread into cons + thread.wait(); + while (cons2.size() > 0) + { + unsigned long id; + std::unique_ptr temp; + cons2.remove_any(id,temp); + cons.add(id,temp); + } + } + + // ------------------------------------------------------------------------------------ + + struct msg_data + { + std::shared_ptr > data; + unsigned long sender_id; + char msg_type; + dlib::uint64 epoch; + + msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {} + + std::string data_to_string() const + { + if (data && data->size() != 0) + return std::string(&(*data)[0], data->size()); + else + return ""; + } + }; + + // ------------------------------------------------------------------------------------ + + class thread_safe_message_queue : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a simple message queue for msg_data objects. Note that it + has the special property that, while messages will generally leave + the queue in the order they are inserted, any message with a smaller + epoch value will always be popped out first. But for all messages + with equal epoch values the queue functions as a normal FIFO queue. + !*/ + private: + struct msg_wrap + { + msg_wrap( + const msg_data& data_, + const dlib::uint64& sequence_number_ + ) : data(data_), sequence_number(sequence_number_) {} + + msg_wrap() : sequence_number(0){} + + msg_data data; + dlib::uint64 sequence_number; + + // Make it so that when msg_wrap objects are in a std::priority_queue, + // messages with a smaller epoch number always come first. Then, within an + // epoch, messages are ordered by their sequence number (so smaller first + // there as well). + bool operator<(const msg_wrap& item) const + { + if (data.epoch < item.data.epoch) + { + return false; + } + else if (data.epoch > item.data.epoch) + { + return true; + } + else + { + if (sequence_number < item.sequence_number) + return false; + else + return true; + } + } + }; + + public: + thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {} + + ~thread_safe_message_queue() + { + disable(); + } + + void disable() + { + auto_mutex lock(class_mutex); + disabled = true; + sig.broadcast(); + } + + unsigned long size() const + { + auto_mutex lock(class_mutex); + return data.size(); + } + + void push_and_consume( msg_data& item) + { + auto_mutex lock(class_mutex); + data.push(msg_wrap(item, next_seq_num++)); + // do this here so that we don't have to worry about different threads touching the shared_ptr. + item.data.reset(); + sig.signal(); + } + + bool pop ( + msg_data& item + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while (data.size() == 0 && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + bool pop ( + msg_data& item, + const dlib::uint64& max_epoch + ) + /*! + ensures + - if (this function returns true) then + - #item == the next thing from the queue that has an epoch <= max_epoch + - else + - this object is disabled + !*/ + { + auto_mutex lock(class_mutex); + while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled) + sig.wait(); + + if (disabled) + return false; + + item = data.top().data; + data.pop(); + + return true; + } + + private: + std::priority_queue data; + dlib::mutex class_mutex; + dlib::signaler sig; + bool disabled; + dlib::uint64 next_seq_num; + }; + + + } + +// ---------------------------------------------------------------------------------------- + + class bsp_context : noncopyable + { + + public: + + template + void send( + const T& item, + unsigned long target_node_id + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(target_node_id < number_of_nodes() && + target_node_id != node_id(), + "\t void bsp_context::send()" + << "\n\t Invalid arguments were given to this function." + << "\n\t target_node_id: " << target_node_id + << "\n\t node_id(): " << node_id() + << "\n\t number_of_nodes(): " << number_of_nodes() + << "\n\t this: " << this + ); + + std::vector buf; + vectorstream sout(buf); + serialize(item, sout); + send_data(buf, target_node_id); + } + + template + void broadcast ( + const T& item + ) + { + std::vector buf; + vectorstream sout(buf); + serialize(item, sout); + for (unsigned long i = 0; i < number_of_nodes(); ++i) + { + // Don't send to yourself. + if (i == node_id()) + continue; + + send_data(buf, i); + } + } + + unsigned long node_id ( + ) const { return _node_id; } + + unsigned long number_of_nodes ( + ) const { return _cons.size()+1; } + + void receive ( + ) + { + unsigned long id; + std::shared_ptr > temp; + if (receive_data(temp,id)) + throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message."); + } + + template + void receive ( + T& item + ) + { + if(!try_receive(item)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template + bool try_receive ( + T& item + ) + { + unsigned long sending_node_id; + return try_receive(item, sending_node_id); + } + + template + void receive ( + T& item, + unsigned long& sending_node_id + ) + { + if(!try_receive(item, sending_node_id)) + throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); + } + + template + bool try_receive ( + T& item, + unsigned long& sending_node_id + ) + { + std::shared_ptr > temp; + if (receive_data(temp, sending_node_id)) + { + vectorstream sin(*temp); + deserialize(item, sin); + if (sin.peek() != EOF) + throw serialization_error("deserialize() did not consume all bytes produced by serialize(). " + "This probably means you are calling a receive method with a different type " + "of object than the one which was sent."); + return true; + } + else + { + return false; + } + } + + ~bsp_context(); + + private: + + bsp_context(); + + bsp_context( + unsigned long node_id_, + impl1::map_id_to_con& cons_ + ); + + void close_all_connections_gracefully(); + /*! + ensures + - closes all the connections to other nodes and lets them know that + we are terminating normally rather than as the result of some kind + of error. + !*/ + + bool receive_data ( + std::shared_ptr >& item, + unsigned long& sending_node_id + ); + + + void notify_control_node ( + char val + ); + + void broadcast_byte ( + char val + ); + + void send_data( + const std::vector& item, + unsigned long target_node_id + ); + /*! + requires + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + !*/ + + + + + unsigned long outstanding_messages; + unsigned long num_waiting_nodes; + unsigned long num_terminated_nodes; + dlib::uint64 current_epoch; + + impl1::thread_safe_message_queue msg_buffer; + + impl1::map_id_to_con& _cons; + const unsigned long _node_id; + array > threads; + + // ----------------------------------- + + template < + typename funct_type + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct + ); + + template < + typename funct_type, + typename ARG1 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + friend void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + + // ----------------------------------- + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + const unsigned long node_id = 0; + connect_all(cons, hosts, node_id); + send_out_connection_orders(cons, hosts); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(listening_port != 0, + "\t void bsp_listen()" + << "\n\t Invalid arguments were given to this function." + ); + + bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3); + obj.close_all_connections_gracefully(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ) + { + impl1::map_id_to_con cons; + unsigned long node_id; + listen_and_connect_all(node_id, cons, listening_port, port_notify_function); + bsp_context obj(node_id, cons); + funct(obj,arg1,arg2,arg3,arg4); + obj.close_all_connections_gracefully(); + } +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#ifdef NO_MAKEFILE +#include "bsp.cpp" +#endif + +#endif // DLIB_BsP_Hh_ + diff --git a/dlib/bsp/bsp_abstract.h b/dlib/bsp/bsp_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..b87f3a0c3478ecb6c25f651e99d528148b37c4ee --- /dev/null +++ b/dlib/bsp/bsp_abstract.h @@ -0,0 +1,912 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BsP_ABSTRACT_Hh_ +#ifdef DLIB_BsP_ABSTRACT_Hh_ + +#include "../noncopyable.h" +#include "../sockets/sockets_extensions_abstract.h" +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class bsp_context : noncopyable + { + /*! + WHAT THIS OBJECT REPRESENTS + This is a tool used to implement algorithms using the Bulk Synchronous + Parallel (BSP) computing model. A BSP algorithm is composed of a number of + processing nodes, each executing in parallel. The general flow of + execution in each processing node is the following: + 1. Do work locally on some data. + 2. Send some messages to other nodes. + 3. Receive messages from other nodes. + 4. Go to step 1 or terminate if complete. + + To do this, each processing node needs an API used to send and receive + messages. This API is implemented by the bsp_connect object which provides + these services to a BSP node. + + Note that BSP processing nodes are spawned using the bsp_connect() and + bsp_listen() routines defined at the bottom of this file. For example, to + start a BSP algorithm consisting of N processing nodes, you would make N-1 + calls to bsp_listen() and one call to bsp_connect(). The call to + bsp_connect() then initiates the computation on all nodes. + + Finally, note that there is no explicit barrier synchronization function + you call at the end of step 3. Instead, you can simply call a method such + as try_receive() until it returns false. That is, the bsp_context's + receive methods incorporate a barrier synchronization that happens once all + the BSP nodes are blocked on receive calls and there are no more messages + in flight. + + + THREAD SAFETY + This object is not thread-safe. In particular, you should only ever have + one thread that works with an instance of this object. This means that, + for example, you should not spawn sub-threads from within a BSP processing + node and have them invoke methods on this object. Instead, you should only + invoke this object's methods from within the BSP processing node's main + thread (i.e. the thread that executes the user supplied function funct()). + !*/ + + public: + + template + void send( + const T& item, + unsigned long target_node_id + ); + /*! + requires + - item is serializable + - target_node_id < number_of_nodes() + - target_node_id != node_id() + ensures + - sends a copy of item to the node with the given id. + throws + - dlib::socket_error: + This exception is thrown if there is an error which prevents us from + delivering the message to the given node. One way this might happen is + if the target node has already terminated its execution or has lost + network connectivity. + !*/ + + template + void broadcast ( + const T& item + ); + /*! + ensures + - item is serializable + - sends a copy of item to all other processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents us from + delivering a message to one of the other nodes. This might happen, for + example, if one of the nodes has terminated its execution or has lost + network connectivity. + !*/ + + unsigned long node_id ( + ) const; + /*! + ensures + - Returns the id of the current processing node. That is, + returns a number N such that: + - N < number_of_nodes() + - N == the node id of the processing node that called node_id(). This + is a number that uniquely identifies the processing node. + !*/ + + unsigned long number_of_nodes ( + ) const; + /*! + ensures + - returns the number of processing nodes participating in the BSP + computation. + !*/ + + template + bool try_receive ( + T& item + ); + /*! + requires + - item is serializable + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - else + - The following must have been true for this function to return false: + - All other nodes were blocked on calls to receive(), + try_receive(), or have terminated. + - There were not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive + methods then they all would have blocked forever. Therefore, + this function only returns false once there are no more messages + to process by any node and there is no possibility of more being + generated until control is returned to the callers of receive + methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all + unblock at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + try_receive(). + !*/ + + template + void receive ( + T& item + ); + /*! + requires + - item is serializable + ensures + - #item == the next message which was sent to the calling processing + node. + - This function is just a wrapper around try_receive() that throws an + exception if a message is not received (i.e. if try_receive() returns + false). + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if there was not a message + to receive. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + receive(). + !*/ + + template + bool try_receive ( + T& item, + unsigned long& sending_node_id + ); + /*! + requires + - item is serializable + ensures + - if (this function returns true) then + - #item == the next message which was sent to the calling processing + node. + - #sending_node_id == the node id of the node that sent this message. + - #sending_node_id < number_of_nodes() + - else + - The following must have been true for this function to return false: + - All other nodes were blocked on calls to receive(), + try_receive(), or have terminated. + - There were not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive + methods then they all would have blocked forever. Therefore, + this function only returns false once there are no more messages + to process by any node and there is no possibility of more being + generated until control is returned to the callers of receive + methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all + unblock at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + try_receive(). + !*/ + + template + void receive ( + T& item, + unsigned long& sending_node_id + ); + /*! + requires + - item is serializable + ensures + - #item == the next message which was sent to the calling processing node. + - #sending_node_id == the node id of the node that sent this message. + - #sending_node_id < number_of_nodes() + - This function is just a wrapper around try_receive() that throws an + exception if a message is not received (i.e. if try_receive() returns + false). + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if there was not a message + to receive. + - dlib::serialization_error or any exception thrown by the global + deserialize(T) routine: + This is thrown if there is a problem in deserialize(). This might + happen if the message sent doesn't match the type T expected by + receive(). + !*/ + + void receive ( + ); + /*! + ensures + - Waits for the following to all be true: + - All other nodes were blocked on calls to receive(), try_receive(), or + have terminated. + - There are not any messages in flight between any nodes. + - That is, if all the nodes had continued to block on receive methods + then they all would have blocked forever. Therefore, this function + only returns once there are no more messages to process by any node + and there is no possibility of more being generated until control is + returned to the callers of receive methods. + - When one BSP node's receive method returns because of the above + conditions then all of them will also return. That is, it is NOT the + case that just a subset of BSP nodes unblock. Moreover, they all unblock + at the same time. + throws + - dlib::socket_error: + This exception is thrown if some error occurs which prevents us from + communicating with other processing nodes or if a message is received + before this function would otherwise return. + + !*/ + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_connect ( + const std::vector& hosts, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function spawns a BSP job consisting of hosts.size()+1 processing nodes. + - The processing node with a node ID of 0 will run locally on the machine + calling bsp_connect(). In particular, this node will execute funct(CONTEXT,arg1,arg2,arg3,arg4), + which is expected to carry out this node's portion of the BSP computation. + - The other processing nodes are executed on the hosts indicated by the input + argument. In particular, this function interprets hosts as a list addresses + identifying machines running the bsp_listen() or bsp_listen_dynamic_port() + routines. + - This call to bsp_connect() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2) will be executed and + it will then be able to participate in the BSP computation as one of the + processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen ( + unsigned short listening_port, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - listening_port != 0 + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - This function will listen on TCP port listening_port for a connection from + bsp_connect(). Once the connection is established, it will close the + listening port so it is free for use by other applications. The connection + and BSP computation will continue uninterrupted. + - This call to bsp_listen() blocks until the BSP computation has completed on + all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT) will be executed and it will + then be able to participate in the BSP computation as one of the processing + nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1) will be executed and it + will then be able to participate in the BSP computation as one of the + processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2) will be executed and + it will then be able to participate in the BSP computation as one of the + processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename port_notify_function_type, + typename funct_type, + typename ARG1, + typename ARG2, + typename ARG3, + typename ARG4 + > + void bsp_listen_dynamic_port ( + unsigned short listening_port, + port_notify_function_type port_notify_function, + funct_type funct, + ARG1 arg1, + ARG2 arg2, + ARG3 arg3, + ARG4 arg4 + ); + /*! + requires + - let CONTEXT be an instance of a bsp_context object. Then: + - funct(CONTEXT,arg1,arg2,arg3,arg4) must be a valid expression + (i.e. funct must be a function or function object) + - port_notify_function((unsigned short) 1234) must be a valid expression + (i.e. port_notify_function() must be a function or function object taking an + unsigned short) + ensures + - This function listens for a connection from the bsp_connect() routine. Once + this connection is established, funct(CONTEXT,arg1,arg2,arg3,arg4) will be + executed and it will then be able to participate in the BSP computation as + one of the processing nodes. + - if (listening_port != 0) then + - This function will listen on TCP port listening_port for a connection + from bsp_connect(). + - else + - An available TCP port number is automatically selected and this function + will listen on it for a connection from bsp_connect(). + - Once a listening port is opened, port_notify_function() is called with the + port number used. This provides a mechanism to find out what listening port + has been used if it is automatically selected. It also allows you to find + out when the routine has begun listening for an incoming connection from + bsp_connect(). + - Once a connection is established, we will close the listening port so it is + free for use by other applications. The connection and BSP computation will + continue uninterrupted. + - This call to bsp_listen_dynamic_port() blocks until the BSP computation has + completed on all processing nodes. + throws + - dlib::socket_error + This exception is thrown if there is an error which prevents the BSP + job from executing. + - Any exception thrown by funct() will be propagated out of this call to + bsp_connect(). + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BsP_ABSTRACT_Hh_ + diff --git a/dlib/byte_orderer.h b/dlib/byte_orderer.h new file mode 100644 index 0000000000000000000000000000000000000000..bc8f6108da769eac99aea17cc51b48aa98bdaf51 --- /dev/null +++ b/dlib/byte_orderer.h @@ -0,0 +1,10 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BYTE_ORDEREr_ +#define DLIB_BYTE_ORDEREr_ + + +#include "byte_orderer/byte_orderer_kernel_1.h" + +#endif // DLIB_BYTE_ORDEREr_ + diff --git a/dlib/byte_orderer/byte_orderer_kernel_1.h b/dlib/byte_orderer/byte_orderer_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8e8342f4ecab9078db51cf0903e2143adef9cc --- /dev/null +++ b/dlib/byte_orderer/byte_orderer_kernel_1.h @@ -0,0 +1,176 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BYTE_ORDEREr_KERNEL_1_ +#define DLIB_BYTE_ORDEREr_KERNEL_1_ + +#include "byte_orderer_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" + +namespace dlib +{ + + class byte_orderer + { + /*! + INITIAL VALUE + - if (this machine is little endian) then + - little_endian == true + - else + - little_endian == false + + CONVENTION + - host_is_big_endian() == !little_endian + - host_is_little_endian() == little_endian + + - if (this machine is little endian) then + - little_endian == true + - else + - little_endian == false + + + !*/ + + + public: + + // this is here for backwards compatibility with older versions of dlib. + typedef byte_orderer kernel_1a; + + byte_orderer ( + ) + { + // This will probably never be false but if it is then it means chars are not 8bits + // on this system. Which is a problem for this object. + COMPILE_TIME_ASSERT(sizeof(short) >= 2); + + unsigned long temp = 1; + unsigned char* ptr = reinterpret_cast(&temp); + if (*ptr == 1) + little_endian = true; + else + little_endian = false; + } + + virtual ~byte_orderer ( + ){} + + bool host_is_big_endian ( + ) const { return !little_endian; } + + bool host_is_little_endian ( + ) const { return little_endian; } + + template < + typename T + > + inline void host_to_network ( + T& item + ) const + { if (little_endian) flip(item); } + + template < + typename T + > + inline void network_to_host ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void host_to_big ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void big_to_host ( + T& item + ) const { if (little_endian) flip(item); } + + template < + typename T + > + void host_to_little ( + T& item + ) const { if (!little_endian) flip(item); } + + template < + typename T + > + void little_to_host ( + T& item + ) const { if (!little_endian) flip(item); } + + + private: + + template < + typename T, + size_t size + > + inline void flip ( + T (&array)[size] + ) const + /*! + ensures + - flips the bytes in every element of this array + !*/ + { + for (size_t i = 0; i < size; ++i) + { + flip(array[i]); + } + } + + template < + typename T + > + inline void flip ( + T& item + ) const + /*! + ensures + - reverses the byte ordering in item + !*/ + { + DLIB_ASSERT_HAS_STANDARD_LAYOUT(T); + + T value; + + // If you are getting this as an error then you are probably using + // this object wrong. If you think you aren't then send me (Davis) an + // email and I'll either set you straight or change/remove this check so + // your stuff works :) + COMPILE_TIME_ASSERT(sizeof(T) <= sizeof(long double)); + + // If you are getting a compile error on this line then it means T is + // a pointer type. It doesn't make any sense to byte swap pointers + // since they have no meaning outside the context of their own process. + // So you probably just forgot to dereference that pointer before passing + // it to this function :) + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + + const size_t size = sizeof(T); + unsigned char* const ptr = reinterpret_cast(&item); + unsigned char* const ptr_temp = reinterpret_cast(&value); + for (size_t i = 0; i < size; ++i) + ptr_temp[size-i-1] = ptr[i]; + + item = value; + } + + bool little_endian; + }; + + // make flip not do anything at all for chars + template <> inline void byte_orderer::flip ( char& ) const {} + template <> inline void byte_orderer::flip ( unsigned char& ) const {} + template <> inline void byte_orderer::flip ( signed char& ) const {} +} + +#endif // DLIB_BYTE_ORDEREr_KERNEL_1_ + diff --git a/dlib/byte_orderer/byte_orderer_kernel_abstract.h b/dlib/byte_orderer/byte_orderer_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..f7ea151035f95cfb7b2114f015568e0501458904 --- /dev/null +++ b/dlib/byte_orderer/byte_orderer_kernel_abstract.h @@ -0,0 +1,149 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BYTE_ORDEREr_ABSTRACT_ +#ifdef DLIB_BYTE_ORDEREr_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + class byte_orderer + { + /*! + INITIAL VALUE + This object has no state. + + WHAT THIS OBJECT REPRESENTS + This object simply provides a mechanism to convert data from a + host machine's own byte ordering to big or little endian and to + also do the reverse. + + It also provides a pair of functions to convert to/from network byte + order where network byte order is big endian byte order. This pair of + functions does the exact same thing as the host_to_big() and big_to_host() + functions and is provided simply so that client code can use the most + self documenting name appropriate. + + Also note that this object is capable of correctly flipping the contents + of arrays when the arrays are declared on the stack. e.g. You can + say things like: + int array[10]; + bo.host_to_network(array); + !*/ + + public: + + byte_orderer ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~byte_orderer ( + ); + /*! + ensures + - any resources associated with *this have been released + !*/ + + bool host_is_big_endian ( + ) const; + /*! + ensures + - if (the host computer is a big endian machine) then + - returns true + - else + - returns false + !*/ + + bool host_is_little_endian ( + ) const; + /*! + ensures + - if (the host computer is a little endian machine) then + - returns true + - else + - returns false + !*/ + + template < + typename T + > + void host_to_network ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to network byte order. + !*/ + + template < + typename T + > + void network_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from network byte order + to host byte order. + !*/ + + template < + typename T + > + void host_to_big ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to big endian byte order. + !*/ + + template < + typename T + > + void big_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from big endian byte order + to host byte order. + !*/ + + template < + typename T + > + void host_to_little ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from host byte order + to little endian byte order. + !*/ + + template < + typename T + > + void little_to_host ( + T& item + ) const; + /*! + ensures + - #item == the value of item converted from little endian byte order + to host byte order. + !*/ + + }; +} + +#endif // DLIB_BYTE_ORDEREr_ABSTRACT_ + diff --git a/dlib/cassert b/dlib/cassert new file mode 100644 index 0000000000000000000000000000000000000000..eb0e59e4176001d73c2070b3dd7cb2a6f9677eef --- /dev/null +++ b/dlib/cassert @@ -0,0 +1 @@ +#include "dlib_include_path_tutorial.txt" diff --git a/dlib/clustering.h b/dlib/clustering.h new file mode 100644 index 0000000000000000000000000000000000000000..3cbd6cfd431171ae275561813abbc84d73ad2937 --- /dev/null +++ b/dlib/clustering.h @@ -0,0 +1,13 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CLuSTERING_ +#define DLIB_CLuSTERING_ + +#include "clustering/modularity_clustering.h" +#include "clustering/chinese_whispers.h" +#include "clustering/spectral_cluster.h" +#include "clustering/bottom_up_cluster.h" +#include "svm/kkmeans.h" + +#endif // DLIB_CLuSTERING_ + diff --git a/dlib/clustering/bottom_up_cluster.h b/dlib/clustering/bottom_up_cluster.h new file mode 100644 index 0000000000000000000000000000000000000000..f80b651087dcda0d3d6e050208142e84c7fd1bc1 --- /dev/null +++ b/dlib/clustering/bottom_up_cluster.h @@ -0,0 +1,253 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_BOTTOM_uP_CLUSTER_Hh_ +#define DLIB_BOTTOM_uP_CLUSTER_Hh_ + +#include +#include + +#include "bottom_up_cluster_abstract.h" +#include "../algs.h" +#include "../matrix.h" +#include "../disjoint_subsets.h" +#include "../graph_utils.h" + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace buc_impl + { + inline void merge_sets ( + matrix& dists, + unsigned long dest, + unsigned long src + ) + { + for (long r = 0; r < dists.nr(); ++r) + dists(dest,r) = dists(r,dest) = std::max(dists(r,dest), dists(r,src)); + } + + struct compare_dist + { + bool operator() ( + const sample_pair& a, + const sample_pair& b + ) const + { + return a.distance() > b.distance(); + } + }; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + unsigned long bottom_up_cluster ( + const matrix_exp& dists_, + std::vector& labels, + unsigned long min_num_clusters, + double max_dist = std::numeric_limits::infinity() + ) + { + matrix dists = matrix_cast(dists_); + // make sure requires clause is not broken + DLIB_CASSERT(dists.nr() == dists.nc() && min_num_clusters > 0, + "\t unsigned long bottom_up_cluster()" + << "\n\t Invalid inputs were given to this function." + << "\n\t dists.nr(): " << dists.nr() + << "\n\t dists.nc(): " << dists.nc() + << "\n\t min_num_clusters: " << min_num_clusters + ); + + using namespace buc_impl; + + labels.resize(dists.nr()); + disjoint_subsets sets; + sets.set_size(dists.nr()); + if (labels.size() == 0) + return 0; + + // push all the edges in the graph into a priority queue so the best edges to merge + // come first. + std::priority_queue, compare_dist> que; + for (long r = 0; r < dists.nr(); ++r) + for (long c = r+1; c < dists.nc(); ++c) + que.push(sample_pair(r,c,dists(r,c))); + + // Now start merging nodes. + for (unsigned long iter = min_num_clusters; iter < sets.size(); ++iter) + { + // find the next best thing to merge. + double best_dist = que.top().distance(); + unsigned long a = sets.find_set(que.top().index1()); + unsigned long b = sets.find_set(que.top().index2()); + que.pop(); + // we have been merging and modifying the distances, so make sure this distance + // is still valid and these guys haven't been merged already. + while(a == b || best_dist < dists(a,b)) + { + // Haven't merged it yet, so put it back in with updated distance for + // reconsideration later. + if (a != b) + que.push(sample_pair(a, b, dists(a, b))); + + best_dist = que.top().distance(); + a = sets.find_set(que.top().index1()); + b = sets.find_set(que.top().index2()); + que.pop(); + } + + + // now merge these sets if the best distance is small enough + if (best_dist > max_dist) + break; + unsigned long news = sets.merge_sets(a,b); + unsigned long olds = (news==a)?b:a; + merge_sets(dists, news, olds); + } + + // figure out which cluster each element is in. Also make sure the labels are + // contiguous. + std::map relabel; + for (unsigned long r = 0; r < labels.size(); ++r) + { + unsigned long l = sets.find_set(r); + // relabel to make contiguous + if (relabel.count(l) == 0) + { + unsigned long next = relabel.size(); + relabel[l] = next; + } + labels[r] = relabel[l]; + } + + + return relabel.size(); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct snl_range + { + snl_range() = default; + snl_range(double val) : lower(val), upper(val) {} + snl_range(double l, double u) : lower(l), upper(u) { DLIB_ASSERT(lower <= upper)} + + double lower = 0; + double upper = 0; + + double width() const { return upper-lower; } + bool operator<(const snl_range& item) const { return lower < item.lower; } + }; + + inline snl_range merge(const snl_range& a, const snl_range& b) + { + return snl_range(std::min(a.lower, b.lower), std::max(a.upper, b.upper)); + } + + inline double distance (const snl_range& a, const snl_range& b) + { + return std::max(a.lower,b.lower) - std::min(a.upper,b.upper); + } + + inline std::ostream& operator<< (std::ostream& out, const snl_range& item ) + { + out << "["< segment_number_line ( + const std::vector& x, + const double max_range_width + ) + { + DLIB_CASSERT(max_range_width >= 0); + + // create initial ranges, one for each value in x. So initially, all the ranges have + // width of 0. + std::vector ranges; + for (auto v : x) + ranges.push_back(v); + std::sort(ranges.begin(), ranges.end()); + + std::vector greedy_final_ranges; + if (ranges.size() == 0) + return greedy_final_ranges; + // We will try two different clustering strategies. One that does a simple greedy left + // to right sweep and another that does a bottom up agglomerative clustering. This + // first loop runs the greedy left to right sweep. Then at the end of this routine we + // will return the results that produced the tightest clustering. + greedy_final_ranges.push_back(ranges[0]); + for (size_t i = 1; i < ranges.size(); ++i) + { + auto m = merge(greedy_final_ranges.back(), ranges[i]); + if (m.width() <= max_range_width) + greedy_final_ranges.back() = m; + else + greedy_final_ranges.push_back(ranges[i]); + } + + + // Here we do the bottom up clustering. So compute the edges connecting our ranges. + // We will simply say there are edges between ranges if and only if they are + // immediately adjacent on the number line. + std::vector edges; + for (size_t i = 1; i < ranges.size(); ++i) + edges.push_back(sample_pair(i-1,i, distance(ranges[i-1],ranges[i]))); + std::sort(edges.begin(), edges.end(), order_by_distance); + + disjoint_subsets sets; + sets.set_size(ranges.size()); + + // Now start merging nodes. + for (auto edge : edges) + { + // find the next best thing to merge. + unsigned long a = sets.find_set(edge.index1()); + unsigned long b = sets.find_set(edge.index2()); + + // merge it if it doesn't result in an interval that's too big. + auto m = merge(ranges[a], ranges[b]); + if (m.width() <= max_range_width) + { + unsigned long news = sets.merge_sets(a,b); + ranges[news] = m; + } + } + + // Now create a list of the final ranges. We will do this by keeping track of which + // range we already added to final_ranges. + std::vector final_ranges; + std::vector already_output(ranges.size(), false); + for (unsigned long i = 0; i < sets.size(); ++i) + { + auto s = sets.find_set(i); + if (!already_output[s]) + { + final_ranges.push_back(ranges[s]); + already_output[s] = true; + } + } + + // only use the greedy clusters if they found a clustering with fewer clusters. + // Otherwise, the bottom up clustering probably produced a more sensible clustering. + if (final_ranges.size() <= greedy_final_ranges.size()) + return final_ranges; + else + return greedy_final_ranges; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOTTOM_uP_CLUSTER_Hh_ + diff --git a/dlib/clustering/bottom_up_cluster_abstract.h b/dlib/clustering/bottom_up_cluster_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..72d362c12574ec3e1ee457232cea481551b600a5 --- /dev/null +++ b/dlib/clustering/bottom_up_cluster_abstract.h @@ -0,0 +1,136 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ +#ifdef DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ + +#include "../matrix.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename EXP + > + unsigned long bottom_up_cluster ( + const matrix_exp& dists, + std::vector& labels, + unsigned long min_num_clusters, + double max_dist = std::numeric_limits::infinity() + ); + /*! + requires + - dists.nr() == dists.nc() + - min_num_clusters > 0 + - dists == trans(dists) + (l.e. dists should be symmetric) + ensures + - Runs a bottom up agglomerative clustering algorithm. + - Interprets dists as a matrix that gives the distances between dists.nr() + items. In particular, we take dists(i,j) to be the distance between the ith + and jth element of some set. This function clusters the elements of this set + into at least min_num_clusters (or dists.nr() if there aren't enough + elements). Additionally, within each cluster, the maximum pairwise distance + between any two cluster elements is <= max_dist. + - returns the number of clusters found. + - #labels.size() == dists.nr() + - for all valid i: + - #labels[i] == the cluster ID of the node with index i (i.e. the node + corresponding to the distances dists(i,*)). + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + !*/ + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + struct snl_range + { + /*! + WHAT THIS OBJECT REPRESENTS + This object represents an interval on the real number line. It is used + to store the outputs of the segment_number_line() routine defined below. + !*/ + + snl_range( + ); + /*! + ensures + - #lower == 0 + - #upper == 0 + !*/ + + snl_range( + double val + ); + /*! + ensures + - #lower == val + - #upper == val + !*/ + + snl_range( + double l, + double u + ); + /*! + requires + - l <= u + ensures + - #lower == l + - #upper == u + !*/ + + double lower; + double upper; + + double width( + ) const { return upper-lower; } + /*! + ensures + - returns the width of this interval on the number line. + !*/ + + bool operator<(const snl_range& item) const { return lower < item.lower; } + /*! + ensures + - provides a total ordering of snl_range objects assuming they are + non-overlapping. + !*/ + }; + + std::ostream& operator<< (std::ostream& out, const snl_range& item ); + /*! + ensures + - prints item to out in the form [lower,upper]. + !*/ + +// ---------------------------------------------------------------------------------------- + + std::vector segment_number_line ( + const std::vector& x, + const double max_range_width + ); + /*! + requires + - max_range_width >= 0 + ensures + - Finds a clustering of the values in x and returns the ranges that define the + clustering. This routine uses a combination of bottom up clustering and a + simple greedy scan to try and find the most compact set of ranges that + contain all the values in x. + - This routine has approximately linear runtime. + - Every value in x will be contained inside one of the returned snl_range + objects; + - All returned snl_range object's will have a width() <= max_range_width and + will also be non-overlapping. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_BOTTOM_uP_CLUSTER_ABSTRACT_Hh_ + + diff --git a/dlib/clustering/chinese_whispers.h b/dlib/clustering/chinese_whispers.h new file mode 100644 index 0000000000000000000000000000000000000000..332cce1a08c046d1e4d5f59d02c06392ddcd6884 --- /dev/null +++ b/dlib/clustering/chinese_whispers.h @@ -0,0 +1,135 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CHINESE_WHISPErS_Hh_ +#define DLIB_CHINESE_WHISPErS_Hh_ + +#include "chinese_whispers_abstract.h" +#include +#include "../rand.h" +#include "../graph_utils/edge_list_graphs.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ordered_by_index(edges), + "\t unsigned long chinese_whispers()" + << "\n\t Invalid inputs were given to this function" + ); + + labels.clear(); + if (edges.size() == 0) + return 0; + + std::vector > neighbors; + find_neighbor_ranges(edges, neighbors); + + // Initialize the labels, each node gets a different label. + labels.resize(neighbors.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + labels[i] = i; + + + for (unsigned long iter = 0; iter < neighbors.size()*num_iterations; ++iter) + { + // Pick a random node. + const unsigned long idx = rnd.get_random_64bit_number()%neighbors.size(); + + // Count how many times each label happens amongst our neighbors. + std::map labels_to_counts; + const unsigned long end = neighbors[idx].second; + for (unsigned long i = neighbors[idx].first; i != end; ++i) + { + labels_to_counts[labels[edges[i].index2()]] += edges[i].distance(); + } + + // find the most common label + std::map::iterator i; + double best_score = -std::numeric_limits::infinity(); + unsigned long best_label = labels[idx]; + for (i = labels_to_counts.begin(); i != labels_to_counts.end(); ++i) + { + if (i->second > best_score) + { + best_score = i->second; + best_label = i->first; + } + } + + labels[idx] = best_label; + } + + + // Remap the labels into a contiguous range. First we find the + // mapping. + std::map label_remap; + for (unsigned long i = 0; i < labels.size(); ++i) + { + const unsigned long next_id = label_remap.size(); + if (label_remap.count(labels[i]) == 0) + label_remap[labels[i]] = next_id; + } + // now apply the mapping to all the labels. + for (unsigned long i = 0; i < labels.size(); ++i) + { + labels[i] = label_remap[labels[i]]; + } + + return label_remap.size(); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ) + { + std::vector oedges; + convert_unordered_to_ordered(edges, oedges); + std::sort(oedges.begin(), oedges.end(), &order_by_index); + + return chinese_whispers(oedges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ) + { + dlib::rand rnd; + return chinese_whispers(edges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ) + { + dlib::rand rnd; + return chinese_whispers(edges, labels, num_iterations, rnd); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CHINESE_WHISPErS_Hh_ + diff --git a/dlib/clustering/chinese_whispers_abstract.h b/dlib/clustering/chinese_whispers_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..7a184c6f9449a38a6ac7140c325e5f3264f80551 --- /dev/null +++ b/dlib/clustering/chinese_whispers_abstract.h @@ -0,0 +1,97 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ +#ifdef DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ + +#include +#include "../rand.h" +#include "../graph_utils/ordered_sample_pair_abstract.h" +#include "../graph_utils/sample_pair_abstract.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ); + /*! + requires + - is_ordered_by_index(edges) == true + ensures + - This function implements the graph clustering algorithm described in the + paper: Chinese Whispers - an Efficient Graph Clustering Algorithm and its + Application to Natural Language Processing Problems by Chris Biemann. + - Interprets edges as a directed graph. That is, it contains the edges on the + said graph and the ordered_sample_pair::distance() values define the edge + weights (larger values indicating a stronger edge connection between the + nodes). If an edge has a distance() value of infinity then it is considered + a "must link" edge. + - returns the number of clusters found. + - #labels.size() == max_index_plus_one(edges) + - for all valid i: + - #labels[i] == the cluster ID of the node with index i in the graph. + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - The algorithm performs exactly num_iterations passes over the graph before + terminating. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations, + dlib::rand& rnd + ); + /*! + ensures + - This function is identical to the above chinese_whispers() routine except + that it operates on a vector of sample_pair objects instead of + ordered_sample_pairs. Therefore, this is simply a convenience routine. In + particular, it is implemented by transforming the given edges into + ordered_sample_pairs and then calling the chinese_whispers() routine defined + above. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ); + /*! + requires + - is_ordered_by_index(edges) == true + ensures + - performs: return chinese_whispers(edges, labels, num_iterations, rnd) + where rnd is a default initialized dlib::rand object. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long chinese_whispers ( + const std::vector& edges, + std::vector& labels, + const unsigned long num_iterations = 100 + ); + /*! + ensures + - performs: return chinese_whispers(edges, labels, num_iterations, rnd) + where rnd is a default initialized dlib::rand object. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CHINESE_WHISPErS_ABSTRACT_Hh_ + diff --git a/dlib/clustering/modularity_clustering.h b/dlib/clustering/modularity_clustering.h new file mode 100644 index 0000000000000000000000000000000000000000..8b8a0b0a58b75b307c5a4f158cac4605ef0c299d --- /dev/null +++ b/dlib/clustering/modularity_clustering.h @@ -0,0 +1,515 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_MODULARITY_ClUSTERING__H__ +#define DLIB_MODULARITY_ClUSTERING__H__ + +#include "modularity_clustering_abstract.h" +#include "../sparse_vector.h" +#include "../graph_utils/edge_list_graphs.h" +#include "../matrix.h" +#include "../rand.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + namespace impl + { + inline double newman_cluster_split ( + dlib::rand& rnd, + const std::vector& edges, + const matrix& node_degrees, // k from the Newman paper + const matrix& Bdiag, // diag(B) from the Newman paper + const double& edge_sum, // m from the Newman paper + matrix& labels, + const double eps, + const unsigned long max_iterations + ) + /*! + requires + - node_degrees.size() == max_index_plus_one(edges) + - Bdiag.size() == max_index_plus_one(edges) + - edges must be sorted according to order_by_index() + ensures + - This routine splits a graph into two subgraphs using the Newman + clustering method. + - returns the modularity obtained when the graph is split according + to the contents of #labels. + - #labels.size() == node_degrees.size() + - for all valid i: #labels(i) == -1 or +1 + - if (this function returns 0) then + - all the labels are equal, i.e. the graph is not split. + !*/ + { + // Scale epsilon so that it is relative to the expected value of an element of a + // unit vector of length node_degrees.size(). + const double power_iter_eps = eps * std::sqrt(1.0/node_degrees.size()); + + // Make a random unit vector and put in labels. + labels.set_size(node_degrees.size()); + for (long i = 0; i < labels.size(); ++i) + labels(i) = rnd.get_random_gaussian(); + labels /= length(labels); + + matrix Bv, Bv_unit; + + // Do the power iteration for a while. + double eig = -1; + double offset = 0; + while (eig < 0) + { + + // any number larger than power_iter_eps + double iteration_change = power_iter_eps*2+1; + for (unsigned long i = 0; i < max_iterations && iteration_change > power_iter_eps; ++i) + { + sparse_matrix_vector_multiply(edges, labels, Bv); + Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; + + if (offset != 0) + { + Bv -= offset*labels; + } + + + const double len = length(Bv); + if (len != 0) + { + Bv_unit = Bv/len; + iteration_change = max(abs(labels-Bv_unit)); + labels.swap(Bv_unit); + } + else + { + // Had a bad time, pick another random vector and try it with the + // power iteration. + for (long i = 0; i < labels.size(); ++i) + labels(i) = rnd.get_random_gaussian(); + } + } + + eig = dot(Bv,labels); + // we will repeat this loop if the largest eigenvalue is negative + offset = eig; + } + + + for (long i = 0; i < labels.size(); ++i) + { + if (labels(i) > 0) + labels(i) = 1; + else + labels(i) = -1; + } + + + // compute B*labels, store result in Bv. + sparse_matrix_vector_multiply(edges, labels, Bv); + Bv -= dot(node_degrees, labels)/(2*edge_sum) * node_degrees; + + // Do some label refinement. In this step we swap labels if it + // improves the modularity score. + bool flipped_label = true; + while(flipped_label) + { + flipped_label = false; + unsigned long idx = 0; + for (long i = 0; i < labels.size(); ++i) + { + const double val = -2*labels(i); + const double increase = 4*Bdiag(i) + 2*val*Bv(i); + + // if there is an increase in modularity for swapping this label + if (increase > 0) + { + labels(i) *= -1; + while (idx < edges.size() && edges[idx].index1() == (unsigned long)i) + { + const long j = edges[idx].index2(); + Bv(j) += val*edges[idx].distance(); + ++idx; + } + + Bv -= (val*node_degrees(i)/(2*edge_sum))*node_degrees; + + flipped_label = true; + } + else + { + while (idx < edges.size() && edges[idx].index1() == (unsigned long)i) + { + ++idx; + } + } + } + } + + + const double modularity = dot(Bv, labels)/(4*edge_sum); + + return modularity; + } + + // ------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster_helper ( + dlib::rand& rnd, + const std::vector& edges, + const matrix& node_degrees, // k from the Newman paper + const matrix& Bdiag, // diag(B) from the Newman paper + const double& edge_sum, // m from the Newman paper + std::vector& labels, + double modularity_threshold, + const double eps, + const unsigned long max_iterations + ) + /*! + ensures + - returns the number of clusters the data was split into + !*/ + { + matrix l; + const double modularity = newman_cluster_split(rnd,edges,node_degrees,Bdiag,edge_sum,l,eps,max_iterations); + + + // We need to collapse the node index values down to contiguous values. So + // we use the following two vectors to contain the mappings from input index + // values to their corresponding index values in each split. + std::vector left_idx_map(node_degrees.size()); + std::vector right_idx_map(node_degrees.size()); + + // figure out how many nodes went into each side of the split. + unsigned long num_left_split = 0; + unsigned long num_right_split = 0; + for (long i = 0; i < l.size(); ++i) + { + if (l(i) > 0) + { + left_idx_map[i] = num_left_split; + ++num_left_split; + } + else + { + right_idx_map[i] = num_right_split; + ++num_right_split; + } + } + + // do a recursive split if it will improve the modularity. + if (modularity > modularity_threshold && num_left_split > 0 && num_right_split > 0) + { + + // split the node_degrees and Bdiag matrices into left and right split parts + matrix left_node_degrees(num_left_split); + matrix right_node_degrees(num_right_split); + matrix left_Bdiag(num_left_split); + matrix right_Bdiag(num_right_split); + for (long i = 0; i < l.size(); ++i) + { + if (l(i) > 0) + { + left_node_degrees(left_idx_map[i]) = node_degrees(i); + left_Bdiag(left_idx_map[i]) = Bdiag(i); + } + else + { + right_node_degrees(right_idx_map[i]) = node_degrees(i); + right_Bdiag(right_idx_map[i]) = Bdiag(i); + } + } + + + // put the edges from one side of the split into split_edges + std::vector split_edges; + modularity_threshold = 0; + for (unsigned long k = 0; k < edges.size(); ++k) + { + const unsigned long i = edges[k].index1(); + const unsigned long j = edges[k].index2(); + const double d = edges[k].distance(); + if (l(i) > 0 && l(j) > 0) + { + split_edges.push_back(ordered_sample_pair(left_idx_map[i], left_idx_map[j], d)); + modularity_threshold += d; + } + } + modularity_threshold -= sum(left_node_degrees*sum(left_node_degrees))/(2*edge_sum); + modularity_threshold /= 4*edge_sum; + + unsigned long num_left_clusters; + std::vector left_labels; + num_left_clusters = newman_cluster_helper(rnd,split_edges,left_node_degrees,left_Bdiag, + edge_sum,left_labels,modularity_threshold, + eps, max_iterations); + + // now load the other side into split_edges and cluster it as well + split_edges.clear(); + modularity_threshold = 0; + for (unsigned long k = 0; k < edges.size(); ++k) + { + const unsigned long i = edges[k].index1(); + const unsigned long j = edges[k].index2(); + const double d = edges[k].distance(); + if (l(i) < 0 && l(j) < 0) + { + split_edges.push_back(ordered_sample_pair(right_idx_map[i], right_idx_map[j], d)); + modularity_threshold += d; + } + } + modularity_threshold -= sum(right_node_degrees*sum(right_node_degrees))/(2*edge_sum); + modularity_threshold /= 4*edge_sum; + + unsigned long num_right_clusters; + std::vector right_labels; + num_right_clusters = newman_cluster_helper(rnd,split_edges,right_node_degrees,right_Bdiag, + edge_sum,right_labels,modularity_threshold, + eps, max_iterations); + + // Now merge the labels from the two splits. + labels.resize(node_degrees.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + { + // if this node was in the left split + if (l(i) > 0) + { + labels[i] = left_labels[left_idx_map[i]]; + } + else // if this node was in the right split + { + labels[i] = right_labels[right_idx_map[i]] + num_left_clusters; + } + } + + + return num_left_clusters + num_right_clusters; + } + else + { + labels.assign(node_degrees.size(),0); + return 1; + } + + } + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(is_ordered_by_index(edges), + "\t unsigned long newman_cluster()" + << "\n\t Invalid inputs were given to this function" + ); + + labels.clear(); + if (edges.size() == 0) + return 0; + + const unsigned long num_nodes = max_index_plus_one(edges); + + // compute the node_degrees vector, edge_sum value, and diag(B). + matrix node_degrees(num_nodes); + matrix Bdiag(num_nodes); + Bdiag = 0; + double edge_sum = 0; + node_degrees = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + node_degrees(edges[i].index1()) += edges[i].distance(); + edge_sum += edges[i].distance(); + if (edges[i].index1() == edges[i].index2()) + Bdiag(edges[i].index1()) += edges[i].distance(); + } + edge_sum /= 2; + Bdiag -= squared(node_degrees)/(2*edge_sum); + + + dlib::rand rnd; + return impl::newman_cluster_helper(rnd,edges,node_degrees,Bdiag,edge_sum,labels,0,eps,max_iterations); + } + +// ---------------------------------------------------------------------------------------- + + inline unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ) + { + std::vector oedges; + convert_unordered_to_ordered(edges, oedges); + std::sort(oedges.begin(), oedges.end(), &order_by_index); + + return newman_cluster(oedges, labels, eps, max_iterations); + } + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + inline std::vector remap_labels ( + const std::vector& labels, + unsigned long& num_labels + ) + /*! + ensures + - This function takes labels and produces a mapping which maps elements of + labels into the most compact range in [0, max] as possible. In particular, + there won't be any unused integers in the mapped range. + - #num_labels == the number of distinct values in labels. + - returns a vector V such that: + - V.size() == labels.size() + - max(mat(V))+1 == num_labels. + - for all valid i,j: + - if (labels[i] == labels[j]) then + - V[i] == V[j] + - else + - V[i] != V[j] + !*/ + { + std::map temp; + for (unsigned long i = 0; i < labels.size(); ++i) + { + if (temp.count(labels[i]) == 0) + { + const unsigned long next = temp.size(); + temp[labels[i]] = next; + } + } + + num_labels = temp.size(); + + std::vector result(labels.size()); + for (unsigned long i = 0; i < labels.size(); ++i) + { + result[i] = temp[labels[i]]; + } + return result; + } + } + +// ---------------------------------------------------------------------------------------- + + inline double modularity ( + const std::vector& edges, + const std::vector& labels + ) + { + const unsigned long num_nodes = max_index_plus_one(edges); + // make sure requires clause is not broken + DLIB_ASSERT(labels.size() == num_nodes, + "\t double modularity()" + << "\n\t Invalid inputs were given to this function" + ); + + unsigned long num_labels; + const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels); + + std::vector cluster_sums(num_labels,0); + std::vector k(num_nodes,0); + + double Q = 0; + double m = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + const unsigned long n1 = edges[i].index1(); + const unsigned long n2 = edges[i].index2(); + k[n1] += edges[i].distance(); + if (n1 != n2) + k[n2] += edges[i].distance(); + + if (n1 != n2) + m += edges[i].distance(); + else + m += edges[i].distance()/2; + + if (labels_[n1] == labels_[n2]) + { + if (n1 != n2) + Q += 2*edges[i].distance(); + else + Q += edges[i].distance(); + } + } + + if (m == 0) + return 0; + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + cluster_sums[labels_[i]] += k[i]; + } + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + Q -= k[i]*cluster_sums[labels_[i]]/(2*m); + } + + return 1.0/(2*m)*Q; + } + +// ---------------------------------------------------------------------------------------- + + inline double modularity ( + const std::vector& edges, + const std::vector& labels + ) + { + const unsigned long num_nodes = max_index_plus_one(edges); + // make sure requires clause is not broken + DLIB_ASSERT(labels.size() == num_nodes, + "\t double modularity()" + << "\n\t Invalid inputs were given to this function" + ); + + + unsigned long num_labels; + const std::vector& labels_ = dlib::impl::remap_labels(labels,num_labels); + + std::vector cluster_sums(num_labels,0); + std::vector k(num_nodes,0); + + double Q = 0; + double m = 0; + for (unsigned long i = 0; i < edges.size(); ++i) + { + const unsigned long n1 = edges[i].index1(); + const unsigned long n2 = edges[i].index2(); + k[n1] += edges[i].distance(); + m += edges[i].distance(); + if (labels_[n1] == labels_[n2]) + { + Q += edges[i].distance(); + } + } + + if (m == 0) + return 0; + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + cluster_sums[labels_[i]] += k[i]; + } + + for (unsigned long i = 0; i < labels_.size(); ++i) + { + Q -= k[i]*cluster_sums[labels_[i]]/m; + } + + return 1.0/m*Q; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MODULARITY_ClUSTERING__H__ + diff --git a/dlib/clustering/modularity_clustering_abstract.h b/dlib/clustering/modularity_clustering_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..c1e7c20c41304230cf992b1326e1f5364a41aacf --- /dev/null +++ b/dlib/clustering/modularity_clustering_abstract.h @@ -0,0 +1,125 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ +#ifdef DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ + +#include +#include "../graph_utils/ordered_sample_pair_abstract.h" +#include "../graph_utils/sample_pair_abstract.h" + +namespace dlib +{ + +// ----------------------------------------------------------------------------------------- + + double modularity ( + const std::vector& edges, + const std::vector& labels + ); + /*! + requires + - labels.size() == max_index_plus_one(edges) + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - Interprets edges as an undirected graph. That is, it contains the edges on + the said graph and the sample_pair::distance() values define the edge weights + (larger values indicating a stronger edge connection between the nodes). + - This function returns the modularity value obtained when the given input + graph is broken into subgraphs according to the contents of labels. In + particular, we say that two nodes with indices i and j are in the same + subgraph or community if and only if labels[i] == labels[j]. + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - See the paper Modularity and community structure in networks by M. E. J. Newman + for a detailed definition. + !*/ + +// ---------------------------------------------------------------------------------------- + + double modularity ( + const std::vector& edges, + const std::vector& labels + ); + /*! + requires + - labels.size() == max_index_plus_one(edges) + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - Interprets edges as a directed graph. That is, it contains the edges on the + said graph and the ordered_sample_pair::distance() values define the edge + weights (larger values indicating a stronger edge connection between the + nodes). Note that, generally, modularity is only really defined for + undirected graphs. Therefore, the "directed graph" given to this function + should have symmetric edges between all nodes. The reason this function is + provided at all is because sometimes a vector of ordered_sample_pair objects + is a useful representation of an undirected graph. + - This function returns the modularity value obtained when the given input + graph is broken into subgraphs according to the contents of labels. In + particular, we say that two nodes with indices i and j are in the same + subgraph or community if and only if labels[i] == labels[j]. + - Duplicate edges are interpreted as if there had been just one edge with a + distance value equal to the sum of all the duplicate edge's distance values. + - See the paper Modularity and community structure in networks by M. E. J. Newman + for a detailed definition. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ); + /*! + requires + - is_ordered_by_index(edges) == true + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - This function performs the clustering algorithm described in the paper + Modularity and community structure in networks by M. E. J. Newman. + - This function interprets edges as a graph and attempts to find the labeling + that maximizes modularity(edges, #labels). + - returns the number of clusters found. + - #labels.size() == max_index_plus_one(edges) + - for all valid i: + - #labels[i] == the cluster ID of the node with index i in the graph. + - 0 <= #labels[i] < the number of clusters found + (i.e. cluster IDs are assigned contiguously and start at 0) + - The main computation of the algorithm is involved in finding an eigenvector + of a certain matrix. To do this, we use the power iteration. In particular, + each time we try to find an eigenvector we will let the power iteration loop + at most max_iterations times or until it reaches an accuracy of eps. + Whichever comes first. + !*/ + +// ---------------------------------------------------------------------------------------- + + unsigned long newman_cluster ( + const std::vector& edges, + std::vector& labels, + const double eps = 1e-4, + const unsigned long max_iterations = 2000 + ); + /*! + requires + - for all valid i: + - 0 <= edges[i].distance() < std::numeric_limits::infinity() + ensures + - This function is identical to the above newman_cluster() routine except that + it operates on a vector of sample_pair objects instead of + ordered_sample_pairs. Therefore, this is simply a convenience routine. In + particular, it is implemented by transforming the given edges into + ordered_sample_pairs and then calling the newman_cluster() routine defined + above. + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_MODULARITY_ClUSTERING_ABSTRACT_Hh_ + diff --git a/dlib/clustering/spectral_cluster.h b/dlib/clustering/spectral_cluster.h new file mode 100644 index 0000000000000000000000000000000000000000..2cac9870fc86f0af667a05bca5cadc1bccd51843 --- /dev/null +++ b/dlib/clustering/spectral_cluster.h @@ -0,0 +1,80 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_SPECTRAL_CLUSTEr_H_ +#define DLIB_SPECTRAL_CLUSTEr_H_ + +#include "spectral_cluster_abstract.h" +#include +#include "../matrix.h" +#include "../svm/kkmeans.h" + +namespace dlib +{ + template < + typename kernel_type, + typename vector_type + > + std::vector spectral_cluster ( + const kernel_type& k, + const vector_type& samples, + const unsigned long num_clusters + ) + { + DLIB_CASSERT(num_clusters > 0, + "\t std::vector spectral_cluster(k,samples,num_clusters)" + << "\n\t num_clusters can't be 0." + ); + + if (num_clusters == 1) + { + // nothing to do, just assign everything to the 0 cluster. + return std::vector(samples.size(), 0); + } + + // compute the similarity matrix. + matrix K(samples.size(), samples.size()); + for (long r = 0; r < K.nr(); ++r) + for (long c = r+1; c < K.nc(); ++c) + K(r,c) = K(c,r) = (double)k(samples[r], samples[c]); + for (long r = 0; r < K.nr(); ++r) + K(r,r) = 0; + + matrix D(K.nr()); + for (long r = 0; r < K.nr(); ++r) + D(r) = sum(rowm(K,r)); + D = sqrt(reciprocal(D)); + K = diagm(D)*K*diagm(D); + matrix u,w,v; + // Use the normal SVD routine unless the matrix is really big, then use the fast + // approximate version. + if (K.nr() < 1000) + svd3(K,u,w,v); + else + svd_fast(K,u,w,v, num_clusters+100, 5); + // Pick out the eigenvectors associated with the largest eigenvalues. + rsort_columns(v,w); + v = colm(v, range(0,num_clusters-1)); + // Now build the normalized spectral vectors, one for each input vector. + std::vector > spec_samps, centers; + for (long r = 0; r < v.nr(); ++r) + { + spec_samps.push_back(trans(rowm(v,r))); + const double len = length(spec_samps.back()); + if (len != 0) + spec_samps.back() /= len; + } + // Finally do the K-means clustering + pick_initial_centers(num_clusters, centers, spec_samps); + find_clusters_using_kmeans(spec_samps, centers); + // And then compute the cluster assignments based on the output of K-means. + std::vector assignments; + for (unsigned long i = 0; i < spec_samps.size(); ++i) + assignments.push_back(nearest_center(centers, spec_samps[i])); + + return assignments; + } + +} + +#endif // DLIB_SPECTRAL_CLUSTEr_H_ + diff --git a/dlib/clustering/spectral_cluster_abstract.h b/dlib/clustering/spectral_cluster_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..880ad80afb10b2b0775491119a39c66e128be6d9 --- /dev/null +++ b/dlib/clustering/spectral_cluster_abstract.h @@ -0,0 +1,43 @@ +// Copyright (C) 2015 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ +#ifdef DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ + +#include + +namespace dlib +{ + template < + typename kernel_type, + typename vector_type + > + std::vector spectral_cluster ( + const kernel_type& k, + const vector_type& samples, + const unsigned long num_clusters + ); + /*! + requires + - samples must be something with an interface compatible with std::vector. + - The following expression must evaluate to a double or float: + k(samples[i], samples[j]) + - num_clusters > 0 + ensures + - Performs the spectral clustering algorithm described in the paper: + On spectral clustering: Analysis and an algorithm by Ng, Jordan, and Weiss. + and returns the results. + - This function clusters the input data samples into num_clusters clusters and + returns a vector that indicates which cluster each sample falls into. In + particular, we return an array A such that: + - A.size() == samples.size() + - A[i] == the cluster assignment of samples[i]. + - for all valid i: 0 <= A[i] < num_clusters + - The "similarity" of samples[i] with samples[j] is given by + k(samples[i],samples[j]). This means that k() should output a number >= 0 + and the number should be larger for samples that are more similar. + !*/ +} + +#endif // DLIB_SPECTRAL_CLUSTEr_ABSTRACT_H_ + + diff --git a/dlib/cmake b/dlib/cmake new file mode 100644 index 0000000000000000000000000000000000000000..224ba3a49160d962a1e1dfa65f1ed85af1bb2cde --- /dev/null +++ b/dlib/cmake @@ -0,0 +1,5 @@ + +cmake_minimum_required(VERSION 3.8.0) + +add_subdirectory(${CMAKE_CURRENT_LIST_DIR} dlib_build) + diff --git a/dlib/cmake_utils/check_if_avx_instructions_executable_on_host.cmake b/dlib/cmake_utils/check_if_avx_instructions_executable_on_host.cmake new file mode 100644 index 0000000000000000000000000000000000000000..4f2cfef933f379d8b7038f1016ed02b7daadca56 --- /dev/null +++ b/dlib/cmake_utils/check_if_avx_instructions_executable_on_host.cmake @@ -0,0 +1,19 @@ +# This script checks if your compiler and host processor can generate and then run programs with AVX instructions. + +cmake_minimum_required(VERSION 3.8.0) + +# Don't rerun this script if its already been executed. +if (DEFINED AVX_IS_AVAILABLE_ON_HOST) + return() +endif() + +# Set to false unless we find out otherwise in the code below. +set(AVX_IS_AVAILABLE_ON_HOST 0) + +try_compile(test_for_avx_worked ${PROJECT_BINARY_DIR}/avx_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_avx + avx_test) + +if(test_for_avx_worked) + message (STATUS "AVX instructions can be executed by the host processor.") + set(AVX_IS_AVAILABLE_ON_HOST 1) +endif() diff --git a/dlib/cmake_utils/check_if_neon_available.cmake b/dlib/cmake_utils/check_if_neon_available.cmake new file mode 100644 index 0000000000000000000000000000000000000000..895c810b74d71e288078e90a52068342965a951c --- /dev/null +++ b/dlib/cmake_utils/check_if_neon_available.cmake @@ -0,0 +1,20 @@ +# This script checks if __ARM_NEON__ is defined for your compiler + +cmake_minimum_required(VERSION 3.8.0) + +# Don't rerun this script if its already been executed. +if (DEFINED ARM_NEON_IS_AVAILABLE) + return() +endif() + +# Set to false unless we find out otherwise in the code below. +set(ARM_NEON_IS_AVAILABLE 0) + +# test if __ARM_NEON__ is defined +try_compile(test_for_neon_worked ${PROJECT_BINARY_DIR}/neon_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_neon + neon_test) + +if(test_for_neon_worked) + message (STATUS "__ARM_NEON__ defined.") + set(ARM_NEON_IS_AVAILABLE 1) +endif() diff --git a/dlib/cmake_utils/check_if_sse4_instructions_executable_on_host.cmake b/dlib/cmake_utils/check_if_sse4_instructions_executable_on_host.cmake new file mode 100644 index 0000000000000000000000000000000000000000..c47560997002166e6f965b7f13c8d2b0b7e9d611 --- /dev/null +++ b/dlib/cmake_utils/check_if_sse4_instructions_executable_on_host.cmake @@ -0,0 +1,19 @@ +# This script checks if your compiler and host processor can generate and then run programs with SSE4 instructions. + +cmake_minimum_required(VERSION 3.8.0) + +# Don't rerun this script if its already been executed. +if (DEFINED SSE4_IS_AVAILABLE_ON_HOST) + return() +endif() + +# Set to false unless we find out otherwise in the code below. +set(SSE4_IS_AVAILABLE_ON_HOST 0) + +try_compile(test_for_sse4_worked ${PROJECT_BINARY_DIR}/sse4_test_build ${CMAKE_CURRENT_LIST_DIR}/test_for_sse4 + sse4_test) + +if(test_for_sse4_worked) + message (STATUS "SSE4 instructions can be executed by the host processor.") + set(SSE4_IS_AVAILABLE_ON_HOST 1) +endif() diff --git a/dlib/cmake_utils/dlib.pc.in b/dlib/cmake_utils/dlib.pc.in new file mode 100644 index 0000000000000000000000000000000000000000..906011033605aba89cba300c81e75066713c6ec6 --- /dev/null +++ b/dlib/cmake_utils/dlib.pc.in @@ -0,0 +1,8 @@ +libdir=@CMAKE_INSTALL_FULL_LIBDIR@ +includedir=@CMAKE_INSTALL_FULL_INCLUDEDIR@ + +Name: @PROJECT_NAME@ +Description: Numerical and networking C++ library +Version: @VERSION@ +Libs: -L${libdir} -ldlib @pkg_config_dlib_needed_libraries@ +Cflags: -I${includedir} @pkg_config_dlib_needed_includes@ diff --git a/dlib/cmake_utils/dlibConfig.cmake.in b/dlib/cmake_utils/dlibConfig.cmake.in new file mode 100644 index 0000000000000000000000000000000000000000..2667a2e718bc9bb8dcb824616a49137c05dcf16a --- /dev/null +++ b/dlib/cmake_utils/dlibConfig.cmake.in @@ -0,0 +1,53 @@ +# =================================================================================== +# The dlib CMake configuration file +# +# ** File generated automatically, do not modify ** +# +# Usage from an external project: +# In your CMakeLists.txt, add these lines: +# +# find_package(dlib REQUIRED) +# target_link_libraries(MY_TARGET_NAME dlib::dlib) +# +# =================================================================================== + + + + +# Our library dependencies (contains definitions for IMPORTED targets) +if(NOT TARGET dlib-shared AND NOT dlib_BINARY_DIR) + # Compute paths + get_filename_component(dlib_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH) + include("${dlib_CMAKE_DIR}/dlib.cmake") + # Check if Threads::Threads target is required and find it if necessary + get_target_property(dlib_deps_threads_check dlib::dlib INTERFACE_LINK_LIBRARIES) + list(FIND dlib_deps_threads_check "Threads::Threads" dlib_deps_threads_idx) + if (${dlib_deps_threads_idx} GREATER -1) + if (NOT TARGET Threads) + find_package(Threads REQUIRED) + endif() + endif() + unset(dlib_deps_threads_idx) + unset(dlib_deps_threads_check) +endif() + +set(dlib_LIBRARIES dlib::dlib) +set(dlib_LIBS dlib::dlib) +set(dlib_INCLUDE_DIRS "@CMAKE_INSTALL_FULL_INCLUDEDIR@" "@dlib_needed_includes@") + +mark_as_advanced(dlib_LIBRARIES) +mark_as_advanced(dlib_LIBS) +mark_as_advanced(dlib_INCLUDE_DIRS) + +# Mark these variables above as deprecated. +function(__deprecated_var var access) + if(access STREQUAL "READ_ACCESS") + message(WARNING "The variable '${var}' is deprecated! Instead, simply use target_link_libraries(your_app dlib::dlib). See http://dlib.net/examples/CMakeLists.txt.html for an example.") + endif() +endfunction() +variable_watch(dlib_LIBRARIES __deprecated_var) +variable_watch(dlib_LIBS __deprecated_var) +variable_watch(dlib_INCLUDE_DIRS __deprecated_var) + + + diff --git a/dlib/cmake_utils/find_blas.cmake b/dlib/cmake_utils/find_blas.cmake new file mode 100644 index 0000000000000000000000000000000000000000..9a56d7c22ec236d955172c352dd45ee8ac1f1e50 --- /dev/null +++ b/dlib/cmake_utils/find_blas.cmake @@ -0,0 +1,480 @@ +# +# This is a CMake makefile. You can find the cmake utility and +# information about it at http://www.cmake.org +# +# +# This cmake file tries to find installed BLAS and LAPACK libraries. +# It looks for an installed copy of the Intel MKL library first and then +# attempts to find some other BLAS and LAPACK libraries if you don't have +# the Intel MKL. +# +# blas_found - True if BLAS is available +# lapack_found - True if LAPACK is available +# found_intel_mkl - True if the Intel MKL library is available +# found_intel_mkl_headers - True if Intel MKL headers are available +# blas_libraries - link against these to use BLAS library +# lapack_libraries - link against these to use LAPACK library +# mkl_libraries - link against these to use the MKL library +# mkl_include_dir - add to the include path to use the MKL library +# openmp_libraries - Set to Intel's OpenMP library if and only if we +# find the MKL. + +# setting this makes CMake allow normal looking if else statements +SET(CMAKE_ALLOW_LOOSE_LOOP_CONSTRUCTS true) + +SET(blas_found 0) +SET(lapack_found 0) +SET(found_intel_mkl 0) +SET(found_intel_mkl_headers 0) +SET(lapack_with_underscore 0) +SET(lapack_without_underscore 0) + +message(STATUS "Searching for BLAS and LAPACK") +INCLUDE(CheckFunctionExists) + +if (UNIX OR MINGW) + message(STATUS "Searching for BLAS and LAPACK") + + if (BUILDING_MATLAB_MEX_FILE) + # # This commented out stuff would link directly to MATLAB's built in + # BLAS and LAPACK. But it's better to not link to anything and do a + #find_library(MATLAB_BLAS_LIBRARY mwblas PATHS ${MATLAB_LIB_FOLDERS} ) + #find_library(MATLAB_LAPACK_LIBRARY mwlapack PATHS ${MATLAB_LIB_FOLDERS} ) + #if (MATLAB_BLAS_LIBRARY AND MATLAB_LAPACK_LIBRARY) + # add_subdirectory(external/cblas) + # set(blas_libraries ${MATLAB_BLAS_LIBRARY} cblas ) + # set(lapack_libraries ${MATLAB_LAPACK_LIBRARY} ) + # set(blas_found 1) + # set(lapack_found 1) + # message(STATUS "Found MATLAB's BLAS and LAPACK libraries") + #endif() + + # We need cblas since MATLAB doesn't provide cblas symbols. + add_subdirectory(external/cblas) + set(blas_libraries cblas ) + set(blas_found 1) + set(lapack_found 1) + message(STATUS "Will link with MATLAB's BLAS and LAPACK at runtime (hopefully!)") + + + ## Don't try to link to anything other than MATLAB's own internal blas + ## and lapack libraries because doing so generally upsets MATLAB. So + ## we just end here no matter what. + return() + endif() + + + # First, search for libraries via pkg-config, which is the cleanest path + find_package(PkgConfig) + pkg_check_modules(BLAS_REFERENCE cblas) + pkg_check_modules(LAPACK_REFERENCE lapack) + # Make sure the cblas found by pkgconfig actually has cblas symbols. + SET(CMAKE_REQUIRED_LIBRARIES "${BLAS_REFERENCE_LDFLAGS}") + CHECK_FUNCTION_EXISTS(cblas_ddot PKGCFG_HAVE_CBLAS) + if (BLAS_REFERENCE_FOUND AND LAPACK_REFERENCE_FOUND AND PKGCFG_HAVE_CBLAS) + set(blas_libraries "${BLAS_REFERENCE_LDFLAGS}") + set(lapack_libraries "${LAPACK_REFERENCE_LDFLAGS}") + set(blas_found 1) + set(lapack_found 1) + set(REQUIRES_LIBS "${REQUIRES_LIBS} cblas lapack") + message(STATUS "Found BLAS and LAPACK via pkg-config") + return() + endif() + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + + if (SIZE_OF_VOID_PTR EQUAL 8) + set( mkl_search_path + /opt/intel/oneapi/mkl/latest/lib/intel64 + /opt/intel/mkl/*/lib/em64t + /opt/intel/mkl/lib/intel64 + /opt/intel/lib/intel64 + /opt/intel/mkl/lib + /opt/intel/tbb/*/lib/em64t/gcc4.7 + /opt/intel/tbb/lib/intel64/gcc4.7 + /opt/intel/tbb/lib/gcc4.7 + ) + + find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) + mark_as_advanced(mkl_intel) + else() + set( mkl_search_path + /opt/intel/oneapi/mkl/latest/lib/ia32 + /opt/intel/mkl/*/lib/32 + /opt/intel/mkl/lib/ia32 + /opt/intel/lib/ia32 + /opt/intel/tbb/*/lib/32/gcc4.7 + /opt/intel/tbb/lib/ia32/gcc4.7 + ) + + find_library(mkl_intel mkl_intel ${mkl_search_path}) + mark_as_advanced(mkl_intel) + endif() + + include(CheckLibraryExists) + + # Get mkl_include_dir + set(mkl_include_search_path + /opt/intel/oneapi/mkl/latest/include + /opt/intel/mkl/include + /opt/intel/include + ) + find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path}) + mark_as_advanced(mkl_include_dir) + + if(NOT DLIB_USE_MKL_SEQUENTIAL AND NOT DLIB_USE_MKL_WITH_TBB) + # Search for the needed libraries from the MKL. We will try to link against the mkl_rt + # file first since this way avoids linking bugs in some cases. + find_library(mkl_rt mkl_rt ${mkl_search_path}) + find_library(openmp_libraries iomp5 ${mkl_search_path}) + mark_as_advanced(mkl_rt openmp_libraries) + # if we found the MKL + if (mkl_rt) + set(mkl_libraries ${mkl_rt} ) + set(blas_libraries ${mkl_rt} ) + set(lapack_libraries ${mkl_rt} ) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + endif() + endif() + + + if (NOT found_intel_mkl) + # Search for the needed libraries from the MKL. This time try looking for a different + # set of MKL files and try to link against those. + find_library(mkl_core mkl_core ${mkl_search_path}) + set(mkl_libs ${mkl_intel} ${mkl_core}) + mark_as_advanced(mkl_libs mkl_intel mkl_core) + + if (DLIB_USE_MKL_WITH_TBB) + find_library(mkl_tbb_thread mkl_tbb_thread ${mkl_search_path}) + find_library(mkl_tbb tbb ${mkl_search_path}) + mark_as_advanced(mkl_tbb_thread mkl_tbb) + list(APPEND mkl_libs ${mkl_tbb_thread} ${mkl_tbb}) + elseif (DLIB_USE_MKL_SEQUENTIAL) + find_library(mkl_sequential mkl_sequential ${mkl_search_path}) + mark_as_advanced(mkl_sequential) + list(APPEND mkl_libs ${mkl_sequential}) + else() + find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) + find_library(mkl_iomp iomp5 ${mkl_search_path}) + find_library(mkl_pthread pthread ${mkl_search_path}) + mark_as_advanced(mkl_thread mkl_iomp mkl_pthread) + list(APPEND mkl_libs ${mkl_thread} ${mkl_iomp} ${mkl_pthread}) + endif() + + # If we found the MKL + if (mkl_intel AND mkl_core AND ((mkl_tbb_thread AND mkl_tbb) OR (mkl_thread AND mkl_iomp AND mkl_pthread) OR mkl_sequential)) + set(mkl_libraries ${mkl_libs}) + set(blas_libraries ${mkl_libs}) + set(lapack_libraries ${mkl_libs}) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + endif() + endif() + + if (found_intel_mkl AND mkl_include_dir) + set(found_intel_mkl_headers 1) + endif() + + # try to find some other LAPACK libraries if we didn't find the MKL + set(extra_paths + /usr/lib64 + /usr/lib64/atlas-sse3 + /usr/lib64/atlas-sse2 + /usr/lib64/atlas + /usr/lib + /usr/lib/atlas-sse3 + /usr/lib/atlas-sse2 + /usr/lib/atlas + /usr/lib/openblas-base + /opt/OpenBLAS/lib + $ENV{OPENBLAS_HOME}/lib + ) + + if (NOT blas_found) + find_library(cblas_lib NAMES openblasp openblas PATHS ${extra_paths}) + if (cblas_lib) + set(blas_libraries ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found OpenBLAS library") + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + # If you compiled OpenBLAS with LAPACK in it then it should have the + # sgetrf_single function in it. So if we find that function in + # OpenBLAS then just use OpenBLAS's LAPACK. + CHECK_FUNCTION_EXISTS(sgetrf_single OPENBLAS_HAS_LAPACK) + if (OPENBLAS_HAS_LAPACK) + message(STATUS "Using OpenBLAS's built in LAPACK") + # set(lapack_libraries gfortran) + set(lapack_found 1) + endif() + endif() + mark_as_advanced( cblas_lib) + endif() + + + if (NOT lapack_found) + find_library(lapack_lib NAMES lapack lapack-3 PATHS ${extra_paths}) + if (lapack_lib) + set(lapack_libraries ${lapack_lib}) + set(lapack_found 1) + message(STATUS "Found LAPACK library") + endif() + mark_as_advanced( lapack_lib) + endif() + + + # try to find some other BLAS libraries if we didn't find the MKL + + if (NOT blas_found) + find_library(atlas_lib atlas PATHS ${extra_paths}) + find_library(cblas_lib cblas PATHS ${extra_paths}) + if (atlas_lib AND cblas_lib) + set(blas_libraries ${atlas_lib} ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found ATLAS BLAS library") + endif() + mark_as_advanced( atlas_lib cblas_lib) + endif() + + # CentOS 7 atlas + if (NOT blas_found) + find_library(tatlas_lib tatlas PATHS ${extra_paths}) + find_library(satlas_lib satlas PATHS ${extra_paths}) + if (tatlas_lib AND satlas_lib ) + set(blas_libraries ${tatlas_lib} ${satlas_lib}) + set(blas_found 1) + message(STATUS "Found ATLAS BLAS library") + endif() + mark_as_advanced( tatlas_lib satlas_lib) + endif() + + + if (NOT blas_found) + find_library(cblas_lib cblas PATHS ${extra_paths}) + if (cblas_lib) + set(blas_libraries ${cblas_lib}) + set(blas_found 1) + message(STATUS "Found CBLAS library") + endif() + mark_as_advanced( cblas_lib) + endif() + + + if (NOT blas_found) + find_library(generic_blas blas PATHS ${extra_paths}) + if (generic_blas) + set(blas_libraries ${generic_blas}) + set(blas_found 1) + message(STATUS "Found BLAS library") + endif() + mark_as_advanced( generic_blas) + endif() + + + + + # Make sure we really found a CBLAS library. That is, it needs to expose + # the proper cblas link symbols. So here we test if one of them is present + # and assume everything is good if it is. Note that we don't do this check if + # we found the Intel MKL since for some reason CHECK_FUNCTION_EXISTS doesn't work + # with it. But it's fine since the MKL should always have cblas. + if (blas_found AND NOT found_intel_mkl) + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + CHECK_FUNCTION_EXISTS(cblas_ddot FOUND_BLAS_HAS_CBLAS) + if (NOT FOUND_BLAS_HAS_CBLAS) + message(STATUS "BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") + set(blas_found 0) + set(lapack_found 0) + endif() + endif() + + + +elseif(WIN32 AND NOT MINGW) + message(STATUS "Searching for BLAS and LAPACK") + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + if (SIZE_OF_VOID_PTR EQUAL 8) + set( mkl_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/tbb/lib/intel64/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/intel64" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/intel64/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/intel64/vc_mt" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/intel64" + "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/intel64" + "C:/Program Files (x86)/Intel/Composer XE/tbb/lib/intel64/vc14" + "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/intel64" + "C:/Program Files/Intel/Composer XE/mkl/lib/intel64" + "C:/Program Files/Intel/Composer XE/tbb/lib/intel64/vc14" + "C:/Program Files/Intel/Composer XE/compiler/lib/intel64" + "C:/Program Files (x86)/Intel/oneAPI/mkl/*/lib/intel64" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/compiler/lib/intel64_win" + ) + set (mkl_redist_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/redist/intel64/compiler" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/redist/intel64_win/compiler" + ) + find_library(mkl_intel mkl_intel_lp64 ${mkl_search_path}) + else() + set( mkl_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/tbb/lib/ia32/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/lib/ia32" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/ia32/vc14" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/tbb/lib/ia32/vc_mt" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/lib/ia32" + "C:/Program Files (x86)/Intel/Composer XE/mkl/lib/ia32" + "C:/Program Files (x86)/Intel/Composer XE/tbb/lib/ia32/vc14" + "C:/Program Files (x86)/Intel/Composer XE/compiler/lib/ia32" + "C:/Program Files/Intel/Composer XE/mkl/lib/ia32" + "C:/Program Files/Intel/Composer XE/tbb/lib/ia32/vc14" + "C:/Program Files/Intel/Composer XE/compiler/lib/ia32" + "C:/Program Files (x86)/Intel/oneAPI/mkl/*/lib/ia32" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/compiler/lib/ia32_win" + + ) + set (mkl_redist_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/redist/ia32/compiler" + "C:/Program Files (x86)/Intel/oneAPI/compiler/*/windows/redist/ia32_win/compiler" + ) + find_library(mkl_intel mkl_intel_c ${mkl_search_path}) + endif() + + + # Get mkl_include_dir + set(mkl_include_search_path + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/mkl/include" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries_*/windows/compiler/include" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/mkl/include" + "C:/Program Files (x86)/IntelSWTools/compilers_and_libraries/windows/compiler/include" + "C:/Program Files (x86)/Intel/Composer XE/mkl/include" + "C:/Program Files (x86)/Intel/Composer XE/compiler/include" + "C:/Program Files/Intel/Composer XE/mkl/include" + "C:/Program Files/Intel/Composer XE/compiler/include" + "C:/Program Files (x86)/Intel/oneAPI/mkl/*/include" + ) + find_path(mkl_include_dir mkl_version.h ${mkl_include_search_path}) + mark_as_advanced(mkl_include_dir) + + # Search for the needed libraries from the MKL. + find_library(mkl_core mkl_core ${mkl_search_path}) + set(mkl_libs ${mkl_intel} ${mkl_core}) + mark_as_advanced(mkl_libs mkl_intel mkl_core) + if (DLIB_USE_MKL_WITH_TBB) + find_library(mkl_tbb_thread mkl_tbb_thread ${mkl_search_path}) + find_library(mkl_tbb tbb ${mkl_search_path}) + mark_as_advanced(mkl_tbb_thread mkl_tbb) + list(APPEND mkl_libs ${mkl_tbb_thread} ${mkl_tbb}) + elseif (DLIB_USE_MKL_SEQUENTIAL) + find_library(mkl_sequential mkl_sequential ${mkl_search_path}) + mark_as_advanced(mkl_sequential) + list(APPEND mkl_libs ${mkl_sequential}) + else() + find_library(mkl_thread mkl_intel_thread ${mkl_search_path}) + mark_as_advanced(mkl_thread) + if (mkl_thread) + find_library(mkl_iomp libiomp5md ${mkl_search_path}) + mark_as_advanced(mkl_iomp) + list(APPEND mkl_libs ${mkl_thread} ${mkl_iomp}) + + # See if we can find the dll that goes with this, so we can copy it to + # the output folder, since a very large number of windows users don't + # understand that they need to add the Intel MKL's folders to their + # PATH to use the Intel MKL. They then complain on the dlib forums. + # Copying the Intel MKL dlls to the output directory removes the need + # to add the Intel MKL to the PATH. + find_file(mkl_iomp_dll "libiomp5md.dll" ${mkl_redist_path}) + if (mkl_iomp_dll) + message(STATUS "FOUND libiomp5md.dll: ${mkl_iomp_dll}") + endif() + endif() + endif() + + # If we found the MKL + if (mkl_intel AND mkl_core AND ((mkl_tbb_thread AND mkl_tbb) OR mkl_sequential OR (mkl_thread AND mkl_iomp))) + set(blas_libraries ${mkl_libs}) + set(lapack_libraries ${mkl_libs}) + set(blas_found 1) + set(lapack_found 1) + set(found_intel_mkl 1) + message(STATUS "Found Intel MKL BLAS/LAPACK library") + + # Make sure the version of the Intel MKL we found is compatible with + # the compiler we are using. One way to do this check is to see if we can + # link to it right now. + set(CMAKE_REQUIRED_LIBRARIES ${blas_libraries}) + CHECK_FUNCTION_EXISTS(cblas_ddot MKL_HAS_CBLAS) + if (NOT MKL_HAS_CBLAS) + message("BLAS library does not have cblas symbols, so dlib will not use BLAS or LAPACK") + set(blas_found 0) + set(lapack_found 0) + endif() + endif() + + if (found_intel_mkl AND mkl_include_dir) + set(found_intel_mkl_headers 1) + endif() + +endif() + + +# When all else fails use CMake's built in functions to find BLAS and LAPACK +if (NOT blas_found) + find_package(BLAS QUIET) + if (${BLAS_FOUND}) + set(blas_libraries ${BLAS_LIBRARIES}) + set(blas_found 1) + if (NOT lapack_found) + find_package(LAPACK QUIET) + if (${LAPACK_FOUND}) + set(lapack_libraries ${LAPACK_LIBRARIES}) + set(lapack_found 1) + endif() + endif() + endif() +endif() + + +# If using lapack, determine whether to mangle functions +if (lapack_found) + include(CheckFortranFunctionExists) + set(CMAKE_REQUIRED_LIBRARIES ${lapack_libraries}) + + check_function_exists("sgesv" LAPACK_FOUND_C_UNMANGLED) + check_function_exists("sgesv_" LAPACK_FOUND_C_MANGLED) + if (CMAKE_Fortran_COMPILER_LOADED) + check_fortran_function_exists("sgesv" LAPACK_FOUND_FORTRAN_UNMANGLED) + check_fortran_function_exists("sgesv_" LAPACK_FOUND_FORTRAN_MANGLED) + endif () + if (LAPACK_FOUND_C_MANGLED OR LAPACK_FOUND_FORTRAN_MANGLED) + set(lapack_with_underscore 1) + elseif (LAPACK_FOUND_C_UNMANGLED OR LAPACK_FOUND_FORTRAN_UNMANGLED) + set(lapack_without_underscore 1) + endif () +endif() + + +if (UNIX OR MINGW) + if (NOT blas_found) + message(" *****************************************************************************") + message(" *** No BLAS library found so using dlib's built in BLAS. However, if you ***") + message(" *** install an optimized BLAS such as OpenBLAS or the Intel MKL your code ***") + message(" *** will run faster. On Ubuntu you can install OpenBLAS by executing: ***") + message(" *** sudo apt-get install libopenblas-dev liblapack-dev ***") + message(" *** Or you can easily install OpenBLAS from source by downloading the ***") + message(" *** source tar file from http://www.openblas.net, extracting it, and ***") + message(" *** running: ***") + message(" *** make; sudo make install ***") + message(" *****************************************************************************") + endif() +endif() diff --git a/dlib/cmake_utils/find_ffmpeg.cmake b/dlib/cmake_utils/find_ffmpeg.cmake new file mode 100644 index 0000000000000000000000000000000000000000..89abd8d6b9fca2947b61d56687e5916f0ffed819 --- /dev/null +++ b/dlib/cmake_utils/find_ffmpeg.cmake @@ -0,0 +1,30 @@ +cmake_minimum_required(VERSION 3.8.0) + +message(STATUS "Searching for FFMPEG/LIBAV") +find_package(PkgConfig REQUIRED) + +if (PkgConfig_FOUND) + pkg_check_modules(FFMPEG IMPORTED_TARGET + libavdevice + libavfilter + libavformat + libavcodec + libavutil + libswresample + libswscale + ) + if (FFMPEG_FOUND) + message(STATUS "Found FFMPEG/LIBAV via pkg-config in `${FFMPEG_LIBRARY_DIRS}`") + else() + message(" *****************************************************************************") + message(" *** No FFMPEG/LIBAV libraries found. ***") + message(" *** On Ubuntu you can install them by executing ***") + message(" *** sudo apt install libavdevice-dev libavfilter-dev libavformat-dev ***") + message(" *** sudo apt install libavcodec-dev libswresample-dev libswscale-dev ***") + message(" *** sudo apt install libavutil-dev ***") + message(" *****************************************************************************") + endif() +else() + message(STATUS "PkgConfig could not be found, FFMPEG won't be available") + set(FFMPEG_FOUND 0) +endif() diff --git a/dlib/cmake_utils/find_libjpeg.cmake b/dlib/cmake_utils/find_libjpeg.cmake new file mode 100644 index 0000000000000000000000000000000000000000..028217b0752ce52e6f75d262b384efa689432f41 --- /dev/null +++ b/dlib/cmake_utils/find_libjpeg.cmake @@ -0,0 +1,38 @@ +#This script just runs CMake's built in JPEG finding tool. But it also checks that the +#copy of libjpeg that cmake finds actually builds and links. + +cmake_minimum_required(VERSION 3.8.0) + +if (BUILDING_PYTHON_IN_MSVC) + # Never use any system copy of libjpeg when building python in visual studio + set(JPEG_FOUND 0) + return() +endif() + +# Don't rerun this script if its already been executed. +if (DEFINED JPEG_FOUND) + return() +endif() + +find_package(JPEG QUIET) + +if(JPEG_FOUND) + set(JPEG_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + try_compile(test_for_libjpeg_worked + ${PROJECT_BINARY_DIR}/test_for_libjpeg_build + ${CMAKE_CURRENT_LIST_DIR}/test_for_libjpeg + test_if_libjpeg_is_broken + CMAKE_FLAGS "${JPEG_TEST_CMAKE_FLAGS}") + + message (STATUS "Found system copy of libjpeg: ${JPEG_LIBRARY}") + if(NOT test_for_libjpeg_worked) + set(JPEG_FOUND 0) + message (STATUS "System copy of libjpeg is broken or too old. Will build our own libjpeg and use that instead.") + endif() +endif() + + diff --git a/dlib/cmake_utils/find_libpng.cmake b/dlib/cmake_utils/find_libpng.cmake new file mode 100644 index 0000000000000000000000000000000000000000..676073922675681a606245d7326b4c4660ffba8d --- /dev/null +++ b/dlib/cmake_utils/find_libpng.cmake @@ -0,0 +1,37 @@ +#This script just runs CMake's built in PNG finding tool. But it also checks that the +#copy of libpng that cmake finds actually builds and links. + +cmake_minimum_required(VERSION 3.8.0) + +if (BUILDING_PYTHON_IN_MSVC) + # Never use any system copy of libpng when building python in visual studio + set(PNG_FOUND 0) + return() +endif() + +# Don't rerun this script if its already been executed. +if (DEFINED PNG_FOUND) + return() +endif() + +find_package(PNG QUIET) + +if(PNG_FOUND) + set(PNG_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + try_compile(test_for_libpng_worked + ${PROJECT_BINARY_DIR}/test_for_libpng_build + ${CMAKE_CURRENT_LIST_DIR}/test_for_libpng + test_if_libpng_is_broken + CMAKE_FLAGS "${PNG_TEST_CMAKE_FLAGS}") + + message (STATUS "Found system copy of libpng: ${PNG_LIBRARIES}") + if(NOT test_for_libpng_worked) + set(PNG_FOUND 0) + message (STATUS "System copy of libpng is broken. Will build our own libpng and use that instead.") + endif() +endif() + diff --git a/dlib/cmake_utils/find_libwebp.cmake b/dlib/cmake_utils/find_libwebp.cmake new file mode 100644 index 0000000000000000000000000000000000000000..022528592634856e2780694fe13a9c3a802798ec --- /dev/null +++ b/dlib/cmake_utils/find_libwebp.cmake @@ -0,0 +1,52 @@ +#============================================================================= +# Find WebP library +# From OpenCV +#============================================================================= +# Find the native WebP headers and libraries. +# +# WEBP_INCLUDE_DIRS - where to find webp/decode.h, etc. +# WEBP_LIBRARIES - List of libraries when using webp. +# WEBP_FOUND - True if webp is found. +#============================================================================= + +# Look for the header file. + +unset(WEBP_FOUND) + +find_path(WEBP_INCLUDE_DIR NAMES webp/decode.h) + +if(NOT WEBP_INCLUDE_DIR) + unset(WEBP_FOUND) +else() + mark_as_advanced(WEBP_INCLUDE_DIR) + + # Look for the library. + find_library(WEBP_LIBRARY NAMES webp) + mark_as_advanced(WEBP_LIBRARY) + + # handle the QUIETLY and REQUIRED arguments and set WEBP_FOUND to TRUE if + # all listed variables are TRUE + include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake) + find_package_handle_standard_args(WebP DEFAULT_MSG WEBP_LIBRARY WEBP_INCLUDE_DIR) + + set(WEBP_LIBRARIES ${WEBP_LIBRARY}) + set(WEBP_INCLUDE_DIRS ${WEBP_INCLUDE_DIR}) +endif() + +if(WEBP_FOUND) + set(WEBP_TEST_CMAKE_FLAGS + "-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}" + "-DCMAKE_INCLUDE_PATH=${CMAKE_INCLUDE_PATH}" + "-DCMAKE_LIBRARY_PATH=${CMAKE_LIBRARY_PATH}") + + try_compile(test_for_libwebp_worked + ${PROJECT_BINARY_DIR}/test_for_libwebp_build + ${CMAKE_CURRENT_LIST_DIR}/test_for_libwebp + test_if_libwebp_is_broken + CMAKE_FLAGS "${WEBP_TEST_CMAKE_FLAGS}") + + if(NOT test_for_libwebp_worked) + set(WEBP_FOUND 0) + message (STATUS "System copy of libwebp is either too old or broken. Will disable WebP support.") + endif() +endif() diff --git a/dlib/cmake_utils/release_build_by_default b/dlib/cmake_utils/release_build_by_default new file mode 100644 index 0000000000000000000000000000000000000000..1b0e958317c7c07af958f6c9e54ae84f682eb47f --- /dev/null +++ b/dlib/cmake_utils/release_build_by_default @@ -0,0 +1,9 @@ + +#set default build type to Release +if(NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release" CACHE STRING + "Choose the type of build, options are: Debug Release + RelWithDebInfo MinSizeRel." FORCE) +endif() + + diff --git a/dlib/cmake_utils/set_compiler_specific_options.cmake b/dlib/cmake_utils/set_compiler_specific_options.cmake new file mode 100644 index 0000000000000000000000000000000000000000..8093ca6d30330444c698b092b08e723c7cbb2ae4 --- /dev/null +++ b/dlib/cmake_utils/set_compiler_specific_options.cmake @@ -0,0 +1,120 @@ + +cmake_minimum_required(VERSION 3.8.0) + + +# Check if we are being built as part of a pybind11 module. +if (COMMAND pybind11_add_module) + # For python users, enable SSE4 and AVX if they have these instructions. + include(${CMAKE_CURRENT_LIST_DIR}/check_if_sse4_instructions_executable_on_host.cmake) + if (SSE4_IS_AVAILABLE_ON_HOST) + set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Compile your program with SSE4 instructions") + endif() + include(${CMAKE_CURRENT_LIST_DIR}/check_if_avx_instructions_executable_on_host.cmake) + if (AVX_IS_AVAILABLE_ON_HOST) + set(USE_AVX_INSTRUCTIONS ON CACHE BOOL "Compile your program with AVX instructions") + endif() + include(${CMAKE_CURRENT_LIST_DIR}/check_if_neon_available.cmake) + if (ARM_NEON_IS_AVAILABLE) + set(USE_NEON_INSTRUCTIONS ON CACHE BOOL "Compile your program with ARM-NEON instructions") + endif() +endif() + + + + +set(gcc_like_compilers GNU Clang Intel) +set(intel_archs x86_64 i386 i686 AMD64 amd64 x86) + + +# Setup some options to allow a user to enable SSE and AVX instruction use. +if ((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND + (";${intel_archs};" MATCHES ";${CMAKE_SYSTEM_PROCESSOR};") AND NOT USE_AUTO_VECTOR) + option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" OFF) + option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) + option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) + if(USE_AVX_INSTRUCTIONS) + list(APPEND active_compile_opts -mavx) + message(STATUS "Enabling AVX instructions") + elseif (USE_SSE4_INSTRUCTIONS) + list(APPEND active_compile_opts -msse4) + message(STATUS "Enabling SSE4 instructions") + elseif(USE_SSE2_INSTRUCTIONS) + list(APPEND active_compile_opts -msse2) + message(STATUS "Enabling SSE2 instructions") + endif() +elseif (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") # else if using Visual Studio + # Use SSE2 by default when using Visual Studio. + option(USE_SSE2_INSTRUCTIONS "Compile your program with SSE2 instructions" ON) + option(USE_SSE4_INSTRUCTIONS "Compile your program with SSE4 instructions" OFF) + option(USE_AVX_INSTRUCTIONS "Compile your program with AVX instructions" OFF) + + include(CheckTypeSize) + check_type_size( "void*" SIZE_OF_VOID_PTR) + if(USE_AVX_INSTRUCTIONS) + list(APPEND active_compile_opts /arch:AVX) + message(STATUS "Enabling AVX instructions") + elseif (USE_SSE4_INSTRUCTIONS) + # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. + # So only give it when we are doing a 32 bit build. + if (SIZE_OF_VOID_PTR EQUAL 4) + list(APPEND active_compile_opts /arch:SSE2) + endif() + message(STATUS "Enabling SSE4 instructions") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE3") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE41") + elseif(USE_SSE2_INSTRUCTIONS) + # Visual studio doesn't have an /arch:SSE2 flag when building in 64 bit modes. + # So only give it when we are doing a 32 bit build. + if (SIZE_OF_VOID_PTR EQUAL 4) + list(APPEND active_compile_opts /arch:SSE2) + endif() + message(STATUS "Enabling SSE2 instructions") + list(APPEND active_preprocessor_switches "-DDLIB_HAVE_SSE2") + endif() + +elseif((";${gcc_like_compilers};" MATCHES ";${CMAKE_CXX_COMPILER_ID};") AND + ("${CMAKE_SYSTEM_PROCESSOR}" MATCHES "^arm")) + option(USE_NEON_INSTRUCTIONS "Compile your program with ARM-NEON instructions" OFF) + if(USE_NEON_INSTRUCTIONS) + list(APPEND active_compile_opts -mfpu=neon) + message(STATUS "Enabling ARM-NEON instructions") + endif() +endif() + + + + +if (CMAKE_COMPILER_IS_GNUCXX) + # By default, g++ won't warn or error if you forget to return a value in a + # function which requires you to do so. This option makes it give a warning + # for doing this. + list(APPEND active_compile_opts "-Wreturn-type") +endif() + +if ("Clang" MATCHES ${CMAKE_CXX_COMPILER_ID} AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0) + # Clang 6 had a default template recursion depth of 256. This was changed to 1024 in Clang 7. + # It must be increased on Clang 6 and below to ensure that the dnn examples don't error out. + list(APPEND active_compile_opts "-ftemplate-depth=500") +endif() + +if (MSVC) + # By default Visual Studio does not support .obj files with more than 65k sections. + # However, code generated by file_to_code_ex and code using DNN module can have + # them. So this flag enables > 65k sections, but produces .obj files + # that will not be readable by VS 2005. + list(APPEND active_compile_opts "/bigobj") + + # Build dlib with all cores. Don't propagate the setting to client programs + # though since they might compile large translation units that use too much + # RAM. + list(APPEND active_compile_opts_private "/MP") + + if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 3.3) + # Clang can compile all Dlib's code at Windows platform. Tested with Clang 5 + list(APPEND active_compile_opts -Xclang) + list(APPEND active_compile_opts -fcxx-exceptions) + endif() +endif() + + diff --git a/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake b/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake new file mode 100644 index 0000000000000000000000000000000000000000..80122d8018fdd9f7d2de2b5656522e176df4e1d1 --- /dev/null +++ b/dlib/cmake_utils/tell_visual_studio_to_use_static_runtime.cmake @@ -0,0 +1,19 @@ + +# Including this cmake script into your cmake project will cause visual studio +# to build your project against the static C runtime. + +cmake_minimum_required(VERSION 3.8.0) + +if (MSVC OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") + option (DLIB_FORCE_MSVC_STATIC_RUNTIME "use static runtime" ON) + if (DLIB_FORCE_MSVC_STATIC_RUNTIME) + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/MD") + string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}") + endif() + endforeach(flag_var) + endif () +endif() + diff --git a/dlib/cmake_utils/test_for_avx/CMakeLists.txt b/dlib/cmake_utils/test_for_avx/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..10890153cf680fa580a92aef1b14ec8a9d9f295f --- /dev/null +++ b/dlib/cmake_utils/test_for_avx/CMakeLists.txt @@ -0,0 +1,23 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(avx_test) + +set(USE_AVX_INSTRUCTIONS ON CACHE BOOL "Use AVX instructions") + +# Pull this in since it sets the AVX compile options by putting that kind of stuff into the active_compile_opts list. +include(../set_compiler_specific_options.cmake) + + +try_run(run_result compile_result ${PROJECT_BINARY_DIR}/avx_test_try_run_build ${CMAKE_CURRENT_LIST_DIR}/avx_test.cpp + COMPILE_DEFINITIONS ${active_compile_opts}) + +message(STATUS "run_result = ${run_result}") +message(STATUS "compile_result = ${compile_result}") + +if ("${run_result}" EQUAL 0 AND compile_result) + message(STATUS "Ran AVX test program successfully, you have AVX available.") +else() + message(STATUS "Unable to run AVX test program, you don't seem to have AVX instructions available.") + # make this build fail so that calling try_compile statements will error in this case. + add_library(make_this_build_fail ${CMAKE_CURRENT_LIST_DIR}/this_file_doesnt_compile.cpp) +endif() diff --git a/dlib/cmake_utils/test_for_avx/avx_test.cpp b/dlib/cmake_utils/test_for_avx/avx_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..98e9b0b99000a012a88cbd8f3672d6025a1ba5b9 --- /dev/null +++ b/dlib/cmake_utils/test_for_avx/avx_test.cpp @@ -0,0 +1,13 @@ + +#include + +int main() +{ + __m256 x; + x = _mm256_set1_ps(1.23); + x = _mm256_add_ps(x,x); + return 0; +} + +// ------------------------------------------------------------------------------------ + diff --git a/dlib/cmake_utils/test_for_avx/this_file_doesnt_compile.cpp b/dlib/cmake_utils/test_for_avx/this_file_doesnt_compile.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83f89c0c753db34dbcf82babecfa61a5c88f0cd3 --- /dev/null +++ b/dlib/cmake_utils/test_for_avx/this_file_doesnt_compile.cpp @@ -0,0 +1,3 @@ + +#error "This file doesn't compile!" + diff --git a/dlib/cmake_utils/test_for_libjpeg/CMakeLists.txt b/dlib/cmake_utils/test_for_libjpeg/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a3c6910dbef12c8ceb947f5c5310ed0522f6ced7 --- /dev/null +++ b/dlib/cmake_utils/test_for_libjpeg/CMakeLists.txt @@ -0,0 +1,11 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(test_if_libjpeg_is_broken) + +find_package(JPEG) + +include_directories(${JPEG_INCLUDE_DIR}) +add_executable(libjpeg_test libjpeg_test.cpp) +target_link_libraries(libjpeg_test ${JPEG_LIBRARY}) + + diff --git a/dlib/cmake_utils/test_for_libjpeg/libjpeg_test.cpp b/dlib/cmake_utils/test_for_libjpeg/libjpeg_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3811d0949b5907034b19bfe4692e88aaa484b60a --- /dev/null +++ b/dlib/cmake_utils/test_for_libjpeg/libjpeg_test.cpp @@ -0,0 +1,68 @@ +// Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include +#include +#include + +struct jpeg_loader_error_mgr +{ + jpeg_error_mgr pub; + jmp_buf setjmp_buffer; +}; + +void jpeg_loader_error_exit (j_common_ptr cinfo) +{ + jpeg_loader_error_mgr* myerr = (jpeg_loader_error_mgr*) cinfo->err; + + longjmp(myerr->setjmp_buffer, 1); +} + +// This code doesn't really make a lot of sense. It's just calling all the libjpeg functions to make +// sure they can be compiled and linked. +int main() +{ + std::cerr << "This program is just for build system testing. Don't actually run it." << std::endl; + abort(); + + FILE *fp = fopen("whatever.jpg", "rb" ); + + jpeg_decompress_struct cinfo; + jpeg_loader_error_mgr jerr; + + cinfo.err = jpeg_std_error(&jerr.pub); + + jerr.pub.error_exit = jpeg_loader_error_exit; + + setjmp(jerr.setjmp_buffer); + + jpeg_create_decompress(&cinfo); + + jpeg_stdio_src(&cinfo, fp); + if (false) { + unsigned char imgbuffer[1234]; + jpeg_mem_src(&cinfo, imgbuffer, sizeof(imgbuffer)); + } + + jpeg_read_header(&cinfo, TRUE); + + jpeg_start_decompress(&cinfo); + + unsigned long height_ = cinfo.output_height; + unsigned long width_ = cinfo.output_width; + unsigned long output_components_ = cinfo.output_components; + + unsigned char* rows[123]; + + while (cinfo.output_scanline < cinfo.output_height) + { + jpeg_read_scanlines(&cinfo, &rows[cinfo.output_scanline], 100); + } + + jpeg_finish_decompress(&cinfo); + jpeg_destroy_decompress(&cinfo); + + fclose( fp ); +} diff --git a/dlib/cmake_utils/test_for_libpng/CMakeLists.txt b/dlib/cmake_utils/test_for_libpng/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..0be92061323718241a6727f8a993f357b28e5335 --- /dev/null +++ b/dlib/cmake_utils/test_for_libpng/CMakeLists.txt @@ -0,0 +1,11 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(test_if_libpng_is_broken) + +find_package(PNG) + +include_directories(${PNG_INCLUDE_DIR}) +add_executable(libpng_test libpng_test.cpp) +target_link_libraries(libpng_test ${PNG_LIBRARIES}) + + diff --git a/dlib/cmake_utils/test_for_libpng/libpng_test.cpp b/dlib/cmake_utils/test_for_libpng/libpng_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..05f439ee94d0abb6e09e71ed3a49bcef1252bd9f --- /dev/null +++ b/dlib/cmake_utils/test_for_libpng/libpng_test.cpp @@ -0,0 +1,54 @@ +// Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. +#include +#include +#include +#include +#include + +void png_loader_user_error_fn_silent(png_structp png_struct, png_const_charp ) +{ + longjmp(png_jmpbuf(png_struct),1); +} +void png_loader_user_warning_fn_silent(png_structp , png_const_charp ) +{ +} + +// This code doesn't really make a lot of sense. It's just calling all the libpng functions to make +// sure they can be compiled and linked. +int main() +{ + std::cerr << "This program is just for build system testing. Don't actually run it." << std::endl; + abort(); + + png_bytep* row_pointers_; + png_structp png_ptr_; + png_infop info_ptr_; + png_infop end_info_; + + FILE *fp = fopen( "whatever.png", "rb" ); + png_byte sig[8]; + fread( sig, 1, 8, fp ); + png_sig_cmp( sig, 0, 8 ); + png_ptr_ = png_create_read_struct( PNG_LIBPNG_VER_STRING, NULL, &png_loader_user_error_fn_silent, &png_loader_user_warning_fn_silent ); + + png_get_header_ver(NULL); + info_ptr_ = png_create_info_struct( png_ptr_ ); + end_info_ = png_create_info_struct( png_ptr_ ); + setjmp(png_jmpbuf(png_ptr_)); + png_set_palette_to_rgb(png_ptr_); + png_init_io( png_ptr_, fp ); + png_set_sig_bytes( png_ptr_, 8 ); + // flags force one byte per channel output + int png_transforms = PNG_TRANSFORM_PACKING; + png_read_png( png_ptr_, info_ptr_, png_transforms, NULL ); + png_get_image_height( png_ptr_, info_ptr_ ); + png_get_image_width( png_ptr_, info_ptr_ ); + png_get_bit_depth( png_ptr_, info_ptr_ ); + png_get_color_type( png_ptr_, info_ptr_ ); + + png_get_rows( png_ptr_, info_ptr_ ); + + fclose(fp); + png_destroy_read_struct(&png_ptr_, &info_ptr_, &end_info_); +} diff --git a/dlib/cmake_utils/test_for_libwebp/CMakeLists.txt b/dlib/cmake_utils/test_for_libwebp/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..9ea3ed10eac178f9a4b3c046f3a6c95ffde344e9 --- /dev/null +++ b/dlib/cmake_utils/test_for_libwebp/CMakeLists.txt @@ -0,0 +1,7 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(test_if_libwebp_is_broken) + +include_directories(${WEBP_INCLUDE_DIR}) +add_executable(libwebp_test libwebp_test.cpp) +target_link_libraries(libwebp_test ${WEBP_LIBRARY}) diff --git a/dlib/cmake_utils/test_for_libwebp/libwebp_test.cpp b/dlib/cmake_utils/test_for_libwebp/libwebp_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e561b4d4fb46035b6c1ae79343f90e21b7c6536e --- /dev/null +++ b/dlib/cmake_utils/test_for_libwebp/libwebp_test.cpp @@ -0,0 +1,22 @@ +// Copyright (C) 2019 Davis E. King (davis@dlib.net), Nils Labugt +// License: Boost Software License See LICENSE.txt for the full license. + +#include +#include +#include + +// This code doesn't really make a lot of sense. It's just calling all the libjpeg functions to make +// sure they can be compiled and linked. +int main() +{ + std::cerr << "This program is just for build system testing. Don't actually run it." << std::endl; + std::abort(); + + uint8_t* data; + size_t output_size = 0; + int width, height, stride; + float quality; + output_size = WebPEncodeRGB(data, width, height, stride, quality, &data); + WebPDecodeRGBInto(data, output_size, data, output_size, stride); + WebPFree(data); +} diff --git a/dlib/cmake_utils/test_for_neon/CMakeLists.txt b/dlib/cmake_utils/test_for_neon/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..b9c6e51e8ae5dcecb6f7b21028aa1e0b6872c42d --- /dev/null +++ b/dlib/cmake_utils/test_for_neon/CMakeLists.txt @@ -0,0 +1,6 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(neon_test) + +add_library(neon_test STATIC neon_test.cpp ) + diff --git a/dlib/cmake_utils/test_for_neon/neon_test.cpp b/dlib/cmake_utils/test_for_neon/neon_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a4abdbadeee68bef850ec338b34b9a9c00289ed3 --- /dev/null +++ b/dlib/cmake_utils/test_for_neon/neon_test.cpp @@ -0,0 +1,9 @@ +#ifdef __ARM_NEON__ +#else +#error "No NEON" +#endif +int main(){} + + +// ------------------------------------------------------------------------------------ + diff --git a/dlib/cmake_utils/test_for_sse4/CMakeLists.txt b/dlib/cmake_utils/test_for_sse4/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..4cd4e95cef131eff56e6ea2f072888a34e76cb9e --- /dev/null +++ b/dlib/cmake_utils/test_for_sse4/CMakeLists.txt @@ -0,0 +1,23 @@ + +cmake_minimum_required(VERSION 3.8.0) +project(sse4_test) + +set(USE_SSE4_INSTRUCTIONS ON CACHE BOOL "Use SSE4 instructions") + +# Pull this in since it sets the SSE4 compile options by putting that kind of stuff into the active_compile_opts list. +include(../set_compiler_specific_options.cmake) + + +try_run(run_result compile_result ${PROJECT_BINARY_DIR}/sse4_test_try_run_build ${CMAKE_CURRENT_LIST_DIR}/sse4_test.cpp + COMPILE_DEFINITIONS ${active_compile_opts}) + +message(STATUS "run_result = ${run_result}") +message(STATUS "compile_result = ${compile_result}") + +if ("${run_result}" EQUAL 0 AND compile_result) + message(STATUS "Ran SSE4 test program successfully, you have SSE4 available.") +else() + message(STATUS "Unable to run SSE4 test program, you don't seem to have SSE4 instructions available.") + # make this build fail so that calling try_compile statements will error in this case. + add_library(make_this_build_fail ${CMAKE_CURRENT_LIST_DIR}/this_file_doesnt_compile.cpp) +endif() diff --git a/dlib/cmake_utils/test_for_sse4/sse4_test.cpp b/dlib/cmake_utils/test_for_sse4/sse4_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8065dab952e36ad4674b3d21af66642882ecbc70 --- /dev/null +++ b/dlib/cmake_utils/test_for_sse4/sse4_test.cpp @@ -0,0 +1,18 @@ + +#include +#include +#include +#include // SSE3 +#include +#include // SSE4 + +int main() +{ + __m128 x; + x = _mm_set1_ps(1.23); + x = _mm_ceil_ps(x); + return 0; +} + +// ------------------------------------------------------------------------------------ + diff --git a/dlib/cmake_utils/test_for_sse4/this_file_doesnt_compile.cpp b/dlib/cmake_utils/test_for_sse4/this_file_doesnt_compile.cpp new file mode 100644 index 0000000000000000000000000000000000000000..83f89c0c753db34dbcf82babecfa61a5c88f0cd3 --- /dev/null +++ b/dlib/cmake_utils/test_for_sse4/this_file_doesnt_compile.cpp @@ -0,0 +1,3 @@ + +#error "This file doesn't compile!" + diff --git a/dlib/cmd_line_parser.h b/dlib/cmd_line_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..fd1148038dd4276579ad50036acaea27fc73bde1 --- /dev/null +++ b/dlib/cmd_line_parser.h @@ -0,0 +1,84 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSEr_ +#define DLIB_CMD_LINE_PARSEr_ + +#include "cmd_line_parser/cmd_line_parser_kernel_1.h" +#include "cmd_line_parser/cmd_line_parser_kernel_c.h" +#include "cmd_line_parser/cmd_line_parser_print_1.h" +#include "cmd_line_parser/cmd_line_parser_check_1.h" +#include "cmd_line_parser/cmd_line_parser_check_c.h" +#include +#include "cmd_line_parser/get_option.h" + +#include "map.h" +#include "sequence.h" + + + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class impl_cmd_line_parser + { + /*! + This class is basically just a big templated typedef for building + a complete command line parser type out of all the parts it needs. + !*/ + + impl_cmd_line_parser() {} + + typedef typename sequence >::kernel_2a sequence_2a; + typedef typename sequence*>::kernel_2a psequence_2a; + typedef typename map,void*>::kernel_1a map_1a_string; + + public: + + typedef cmd_line_parser_kernel_1 kernel_1a; + typedef cmd_line_parser_kernel_c kernel_1a_c; + typedef cmd_line_parser_print_1 print_1a_c; + typedef cmd_line_parser_check_c > check_1a_c; + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename charT + > + class cmd_line_parser : public impl_cmd_line_parser::check_1a_c + { + public: + + // These typedefs are here for backwards compatibility with previous versions of dlib. + typedef cmd_line_parser kernel_1a; + typedef cmd_line_parser kernel_1a_c; + typedef cmd_line_parser print_1a; + typedef cmd_line_parser print_1a_c; + typedef cmd_line_parser check_1a; + typedef cmd_line_parser check_1a_c; + }; + + template < + typename charT + > + inline void swap ( + cmd_line_parser& a, + cmd_line_parser& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- + + typedef cmd_line_parser command_line_parser; + typedef cmd_line_parser wcommand_line_parser; + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSEr_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_check_1.h b/dlib/cmd_line_parser/cmd_line_parser_check_1.h new file mode 100644 index 0000000000000000000000000000000000000000..1736b4b56a03e0638428dbd26b5c24192dac4a15 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_check_1.h @@ -0,0 +1,580 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_CHECk_1_ +#define DLIB_CMD_LINE_PARSER_CHECk_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include +#include +#include "../string.h" +#include + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_check_1 : public clp_base + { + + /*! + This extension doesn't add any state. + + !*/ + + + public: + typedef typename clp_base::char_type char_type; + typedef typename clp_base::string_type string_type; + + // ------------------------------------------------------------------------------------ + + class cmd_line_check_error : public dlib::error + { + friend class cmd_line_parser_check_1; + + cmd_line_check_error( + error_type t, + const string_type& opt_, + const string_type& arg_ + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(arg_), + required_opts() + { set_info_string(); } + + cmd_line_check_error( + error_type t, + const string_type& opt_, + const string_type& opt2_, + int // this is just to make this constructor different from the one above + ) : + dlib::error(t), + opt(opt_), + opt2(opt2_), + arg(), + required_opts() + { set_info_string(); } + + cmd_line_check_error ( + error_type t, + const string_type& opt_, + const std::vector& vect + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(), + required_opts(vect) + { set_info_string(); } + + cmd_line_check_error( + error_type t, + const string_type& opt_ + ) : + dlib::error(t), + opt(opt_), + opt2(), + arg(), + required_opts() + { set_info_string(); } + + ~cmd_line_check_error() throw() {} + + void set_info_string ( + ) + { + std::ostringstream sout; + switch (type) + { + case EINVALID_OPTION_ARG: + sout << "Command line error: '" << narrow(arg) << "' is not a valid argument to " + << "the '" << narrow(opt) << "' option."; + break; + case EMISSING_REQUIRED_OPTION: + if (required_opts.size() == 1) + { + sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of " + << "the '" << required_opts[0] << "' option."; + } + else + { + sout << "Command line error: The '" << narrow(opt) << "' option requires the presence of " + << "one of the following options: "; + for (unsigned long i = 0; i < required_opts.size(); ++i) + { + if (i == required_opts.size()-2) + sout << "'" << required_opts[i] << "' or "; + else if (i == required_opts.size()-1) + sout << "'" << required_opts[i] << "'."; + else + sout << "'" << required_opts[i] << "', "; + } + } + break; + case EINCOMPATIBLE_OPTIONS: + sout << "Command line error: The '" << narrow(opt) << "' and '" << narrow(opt2) + << "' options cannot be given together on the command line."; + break; + case EMULTIPLE_OCCURANCES: + sout << "Command line error: The '" << narrow(opt) << "' option can only " + << "be given on the command line once."; + break; + default: + sout << "Command line error."; + break; + } + const_cast(info) = wrap_string(sout.str(),0,0); + } + + public: + const string_type opt; + const string_type opt2; + const string_type arg; + const std::vector required_opts; + }; + + // ------------------------------------------------------------------------------------ + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + }; + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_check_1& a, + cmd_line_parser_check_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_1:: + check_option_arg_type ( + const string_type& option_name + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + string_cast(opt.argument(i,j)); + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + T temp(string_cast(opt.argument(i,j))); + if (temp < first || last < temp) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < typename T, size_t length > + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const + { + try + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + T temp(string_cast(opt.argument(i,j))); + size_t k = 0; + for (; k < length; ++k) + { + if (arg_set[k] == temp) + break; + } + if (k == length) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + catch (string_cast_error& e) + { + throw cmd_line_check_error(EINVALID_OPTION_ARG,option_name,e.info); + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const + { + const typename clp_base::option_type& opt = this->option(option_name); + const unsigned long number_of_arguments = opt.number_of_arguments(); + const unsigned long count = opt.count(); + for (unsigned long i = 0; i < number_of_arguments; ++i) + { + for (unsigned long j = 0; j < count; ++j) + { + size_t k = 0; + for (; k < length; ++k) + { + if (arg_set[k] == opt.argument(i,j)) + break; + } + if (k == length) + { + throw cmd_line_check_error( + EINVALID_OPTION_ARG, + option_name, + opt.argument(i,j) + ); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_incompatible_options ( + const char_type* (&option_set)[length] + ) const + { + for (size_t i = 0; i < length; ++i) + { + for (size_t j = i+1; j < length; ++j) + { + if (this->option(option_set[i]).count() > 0 && + this->option(option_set[j]).count() > 0 ) + { + throw cmd_line_check_error( + EINCOMPATIBLE_OPTIONS, + option_set[i], + option_set[j], + 0 // this argument has no meaning and is only here to make this + // call different from the other constructor + ); + } + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_1:: + check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const + { + if (this->option(option_name1).count() > 0 && + this->option(option_name2).count() > 0 ) + { + throw cmd_line_check_error( + EINCOMPATIBLE_OPTIONS, + option_name1, + option_name2, + 0 // this argument has no meaning and is only here to make this + // call different from the other constructor + ); + } + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_1:: + check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const + { + if (this->option(parent_option).count() == 0) + { + if (this->option(sub_option).count() != 0) + { + std::vector vect; + vect.resize(1); + vect[0] = parent_option; + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const + { + if (this->option(parent_option).count() == 0) + { + size_t i = 0; + for (; i < length; ++i) + { + if (this->option(sub_option_set[i]).count() > 0) + break; + } + if (i != length) + { + std::vector vect; + vect.resize(1); + vect[0] = parent_option; + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const + { + // first check if the sub_option is present + if (this->option(sub_option).count() > 0) + { + // now check if any of the parents are present + bool parents_present = false; + for (size_t i = 0; i < length; ++i) + { + if (this->option(parent_option_set[i]).count() > 0) + { + parents_present = true; + break; + } + } + + if (!parents_present) + { + std::vector vect(parent_option_set, parent_option_set+length); + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option, vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t parent_length, size_t sub_length > + void cmd_line_parser_check_1:: + check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const + { + // first check if any of the parent options are present + bool parents_present = false; + for (size_t i = 0; i < parent_length; ++i) + { + if (this->option(parent_option_set[i]).count() > 0) + { + parents_present = true; + break; + } + } + + if (!parents_present) + { + // none of these sub options should be present + size_t i = 0; + for (; i < sub_length; ++i) + { + if (this->option(sub_option_set[i]).count() > 0) + break; + } + if (i != sub_length) + { + std::vector vect(parent_option_set, parent_option_set+parent_length); + throw cmd_line_check_error( EMISSING_REQUIRED_OPTION, sub_option_set[i], vect); + } + } + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_1:: + check_one_time_options ( + const char_type* (&option_set)[length] + ) const + { + size_t i = 0; + for (; i < length; ++i) + { + if (this->option(option_set[i]).count() > 1) + break; + } + if (i != length) + { + throw cmd_line_check_error( + EMULTIPLE_OCCURANCES, + option_set[i] + ); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_CHECk_1_ + + diff --git a/dlib/cmd_line_parser/cmd_line_parser_check_c.h b/dlib/cmd_line_parser/cmd_line_parser_check_c.h new file mode 100644 index 0000000000000000000000000000000000000000..7ff858e8985efc86526484df824b13f4dcfa2b94 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_check_c.h @@ -0,0 +1,453 @@ +// Copyright (C) 2006 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_CHECk_C_ +#define DLIB_CMD_LINE_PARSER_CHECk_C_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include +#include "../interfaces/cmd_line_parser_option.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename clp_check + > + class cmd_line_parser_check_c : public clp_check + { + public: + + typedef typename clp_check::char_type char_type; + typedef typename clp_check::string_type string_type; + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + }; + + + template < + typename clp_check + > + inline void swap ( + cmd_line_parser_check_c& a, + cmd_line_parser_check_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_c:: + check_option_arg_type ( + const string_type& option_name + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_type()" + << "\n\tYou must have already parsed the command line and option_name must be valid." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::template check_option_arg_type(option_name); + } + +// ---------------------------------------------------------------------------------------- + + template + template + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name) && + first <= last, + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + << "\n\tfirst: " << first + << "\n\tlast: " << last + ); + + clp_check::check_option_arg_range(option_name,first,last); + } + +// ---------------------------------------------------------------------------------------- + + template + template < typename T, size_t length > + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const + { + COMPILE_TIME_ASSERT(is_pointer_type::value == false); + + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::check_option_arg_range(option_name,arg_set); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name), + "\tvoid cmd_line_parser_check::check_option_arg_range()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name): " << ((this->option_is_defined(option_name))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name: " << option_name + ); + + clp_check::check_option_arg_range(option_name,arg_set); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_incompatible_options ( + const char_type* (&option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]), + "\tvoid cmd_line_parser_check::check_incompatible_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_set[i]: " << option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + clp_check::check_incompatible_options(option_set); + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_c:: + check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_name1) && + this->option_is_defined(option_name2), + "\tvoid cmd_line_parser_check::check_incompatible_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_name1): " << ((this->option_is_defined(option_name1))?"true":"false") + << "\n\toption_is_defined(option_name2): " << ((this->option_is_defined(option_name2))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_name1: " << option_name1 + << "\n\toption_name2: " << option_name2 + ); + + clp_check::check_incompatible_options(option_name1,option_name2); + } + +// ---------------------------------------------------------------------------------------- + + template + void cmd_line_parser_check_c:: + check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option) && + this->option_is_defined(sub_option), + "\tvoid cmd_line_parser_check::check_sub_option()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\tparsed_line(): " << this->parsed_line() + << "\n\toption_is_defined(parent_option): " << this->option_is_defined(parent_option) + << "\n\toption_is_defined(sub_option): " << this->option_is_defined(sub_option) + << "\n\tparent_option: " << parent_option + << "\n\tsub_option: " << sub_option + ); + clp_check::check_sub_option(parent_option,sub_option); + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->option_is_defined(sub_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option_set[i]): " + << ((this->option_is_defined(sub_option_set[i]))?"true":"false") + << "\n\tsub_option_set[i]: " << sub_option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(parent_option), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option): " << ((this->option_is_defined(parent_option))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\tparent_option: " << parent_option + ); + clp_check::check_sub_options(parent_option,sub_option_set); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->option_is_defined(parent_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option_set[i]): " + << ((this->option_is_defined(parent_option_set[i]))?"true":"false") + << "\n\tparent_option_set[i]: " << parent_option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(sub_option), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option): " << ((this->option_is_defined(sub_option))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\tsub_option: " << sub_option + ); + clp_check::check_sub_options(parent_option_set,sub_option); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t parent_length, size_t sub_length > + void cmd_line_parser_check_c:: + check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < sub_length; ++i) + { + DLIB_CASSERT( this->option_is_defined(sub_option_set[i]), + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(sub_option_set[i]): " + << ((this->option_is_defined(sub_option_set[i]))?"true":"false") + << "\n\tsub_option_set[i]: " << sub_option_set[i] + << "\n\ti: " << static_cast(i) + ); + } + + for (size_t i = 0; i < parent_length; ++i) + { + DLIB_CASSERT( this->option_is_defined(parent_option_set[i]), + "\tvoid cmd_line_parser_check::check_parent_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(parent_option_set[i]): " + << ((this->option_is_defined(parent_option_set[i]))?"true":"false") + << "\n\tparent_option_set[i]: " << parent_option_set[i] + << "\n\ti: " << static_cast(i) + ); + } + + + + DLIB_CASSERT( this->parsed_line() == true , + "\tvoid cmd_line_parser_check::check_sub_options()" + << "\n\tYou must have parsed the command line before you call this function." + << "\n\tthis: " << this + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + ); + + clp_check::check_sub_options(parent_option_set,sub_option_set); + + } + +// ---------------------------------------------------------------------------------------- + + template + template < size_t length > + void cmd_line_parser_check_c:: + check_one_time_options ( + const char_type* (&option_set)[length] + ) const + { + // make sure requires clause is not broken + for (size_t i = 0; i < length; ++i) + { + DLIB_CASSERT( this->parsed_line() == true && this->option_is_defined(option_set[i]), + "\tvoid cmd_line_parser_check::check_one_time_options()" + << "\n\tSee the requires clause for this function." + << "\n\tthis: " << this + << "\n\toption_is_defined(option_set[i]): " << ((this->option_is_defined(option_set[i]))?"true":"false") + << "\n\tparsed_line(): " << ((this->parsed_line())?"true":"false") + << "\n\toption_set[i]: " << option_set[i] + << "\n\ti: " << static_cast(i) + ); + + } + clp_check::check_one_time_options(option_set); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_CHECk_C_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h b/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..68ea5a135078d8063b372100467d7e96691f617c --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_kernel_1.h @@ -0,0 +1,799 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_KERNEl_1_ +#define DLIB_CMD_LINE_PARSER_KERNEl_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include +#include +#include "../interfaces/enumerable.h" +#include "../interfaces/cmd_line_parser_option.h" +#include "../assert.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + class cmd_line_parser_kernel_1 : public enumerable > + { + /*! + REQUIREMENTS ON map + is an implementation of map/map_kernel_abstract.h + is instantiated to map items of type std::basic_string to void* + + REQUIREMENTS ON sequence + is an implementation of sequence/sequence_kernel_abstract.h and + is instantiated with std::basic_string + + REQUIREMENTS ON sequence2 + is an implementation of sequence/sequence_kernel_abstract.h and + is instantiated with std::basic_string* + + INITIAL VALUE + options.size() == 0 + argv.size() == 0 + have_parsed_line == false + + CONVENTION + have_parsed_line == parsed_line() + argv[index] == operator[](index) + argv.size() == number_of_arguments() + *((option_t*)options[name]) == option(name) + options.is_in_domain(name) == option_is_defined(name) + !*/ + + + + + public: + + typedef charT char_type; + typedef std::basic_string string_type; + typedef cmd_line_parser_option option_type; + + // exception class + class cmd_line_parse_error : public dlib::error + { + void set_info_string ( + ) + { + std::ostringstream sout; + switch (type) + { + case EINVALID_OPTION: + sout << "Command line error: '" << narrow(item) << "' is not a valid option."; + break; + case ETOO_FEW_ARGS: + if (num > 1) + { + sout << "Command line error: The '" << narrow(item) << "' option requires " << num + << " arguments."; + } + else + { + sout << "Command line error: The '" << narrow(item) << "' option requires " << num + << " argument."; + } + break; + case ETOO_MANY_ARGS: + sout << "Command line error: The '" << narrow(item) << "' option does not take any arguments.\n"; + break; + default: + sout << "Command line error."; + break; + } + const_cast(info) = wrap_string(sout.str(),0,0); + } + + public: + cmd_line_parse_error( + error_type t, + const std::basic_string& _item + ) : + dlib::error(t), + item(_item), + num(0) + { set_info_string();} + + cmd_line_parse_error( + error_type t, + const std::basic_string& _item, + unsigned long _num + ) : + dlib::error(t), + item(_item), + num(_num) + { set_info_string();} + + cmd_line_parse_error( + ) : + dlib::error(), + item(), + num(0) + { set_info_string();} + + ~cmd_line_parse_error() throw() {} + + const std::basic_string item; + const unsigned long num; + }; + + + private: + + class option_t : public cmd_line_parser_option + { + /*! + INITIAL VALUE + options.size() == 0 + + CONVENTION + name_ == name() + description_ == description() + number_of_arguments_ == number_of_arguments() + options[N][arg] == argument(arg,N) + num_present == count() + !*/ + + friend class cmd_line_parser_kernel_1; + + public: + + const std::basic_string& name ( + ) const { return name_; } + + const std::basic_string& group_name ( + ) const { return group_name_; } + + const std::basic_string& description ( + ) const { return description_; } + + unsigned long number_of_arguments( + ) const { return number_of_arguments_; } + + unsigned long count ( + ) const { return num_present; } + + const std::basic_string& argument ( + unsigned long arg, + unsigned long N + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( N < count() && arg < number_of_arguments(), + "\tconst string_type& cmd_line_parser_option::argument(unsigned long,unsigned long)" + << "\n\tInvalid arguments were given to this function." + << "\n\tthis: " << this + << "\n\tN: " << N + << "\n\targ: " << arg + << "\n\tname(): " << narrow(name()) + << "\n\tcount(): " << count() + << "\n\tnumber_of_arguments(): " << number_of_arguments() + ); + + return options[N][arg]; + } + + protected: + + option_t ( + ) : + num_present(0) + {} + + ~option_t() + { + clear(); + } + + private: + + void clear() + /*! + ensures + - #count() == 0 + - clears everything out of options and frees memory + !*/ + { + for (unsigned long i = 0; i < options.size(); ++i) + { + delete [] options[i]; + } + options.clear(); + num_present = 0; + } + + // data members + std::basic_string name_; + std::basic_string group_name_; + std::basic_string description_; + sequence2 options; + unsigned long number_of_arguments_; + unsigned long num_present; + + + + // restricted functions + option_t(option_t&); // copy constructor + option_t& operator=(option_t&); // assignment operator + }; + + // -------------------------- + + public: + + cmd_line_parser_kernel_1 ( + ); + + virtual ~cmd_line_parser_kernel_1 ( + ); + + void clear( + ); + + void parse ( + int argc, + const charT** argv + ); + + void parse ( + int argc, + charT** argv + ) + { + parse(argc, const_cast(argv)); + } + + bool parsed_line( + ) const; + + bool option_is_defined ( + const string_type& name + ) const; + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + + void set_group_name ( + const string_type& group_name + ); + + string_type get_group_name ( + ) const { return group_name; } + + const cmd_line_parser_option& option ( + const string_type& name + ) const; + + unsigned long number_of_arguments( + ) const; + + const string_type& operator[] ( + unsigned long index + ) const; + + void swap ( + cmd_line_parser_kernel_1& item + ); + + // functions from the enumerable interface + bool at_start ( + ) const { return options.at_start(); } + + void reset ( + ) const { options.reset(); } + + bool current_element_valid ( + ) const { return options.current_element_valid(); } + + const cmd_line_parser_option& element ( + ) const { return *static_cast*>(options.element().value()); } + + cmd_line_parser_option& element ( + ) { return *static_cast*>(options.element().value()); } + + bool move_next ( + ) const { return options.move_next(); } + + size_t size ( + ) const { return options.size(); } + + private: + + // data members + map options; + sequence argv; + bool have_parsed_line; + string_type group_name; + + // restricted functions + cmd_line_parser_kernel_1(cmd_line_parser_kernel_1&); // copy constructor + cmd_line_parser_kernel_1& operator=(cmd_line_parser_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + inline void swap ( + cmd_line_parser_kernel_1& a, + cmd_line_parser_kernel_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + cmd_line_parser_kernel_1:: + cmd_line_parser_kernel_1 ( + ) : + have_parsed_line(false) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + cmd_line_parser_kernel_1:: + ~cmd_line_parser_kernel_1 ( + ) + { + // delete all option_t objects in options + options.reset(); + while (options.move_next()) + { + delete static_cast(options.element().value()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + clear( + ) + { + have_parsed_line = false; + argv.clear(); + + + // delete all option_t objects in options + options.reset(); + while (options.move_next()) + { + delete static_cast(options.element().value()); + } + options.clear(); + reset(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + parse ( + int argc_, + const charT** argv + ) + { + using namespace std; + + // make sure there aren't any arguments hanging around from the last time + // parse was called + this->argv.clear(); + + // make sure that the options have been cleared of any arguments since + // the last time parse() was called + if (have_parsed_line) + { + options.reset(); + while (options.move_next()) + { + static_cast(options.element().value())->clear(); + } + options.reset(); + } + + // this tells us if we have seen -- on the command line all by itself + // or not. + bool escape = false; + + const unsigned long argc = static_cast(argc_); + try + { + + for (unsigned long i = 1; i < argc; ++i) + { + if (argv[i][0] == _dT(charT,'-') && !escape) + { + // we are looking at the start of an option + + // -------------------------------------------------------------------- + if (argv[i][1] == _dT(charT,'-')) + { + // we are looking at the start of a "long named" option + string_type temp = &argv[i][2]; + string_type first_argument; + typename string_type::size_type pos = temp.find_first_of(_dT(charT,'=')); + // This variable will be 1 if there is an argument supplied via the = sign + // and 0 otherwise. + unsigned long extra_argument = 0; + if (pos != string_type::npos) + { + // there should be an extra argument + extra_argument = 1; + first_argument = temp.substr(pos+1); + temp = temp.substr(0,pos); + } + + // make sure this name is defined + if (!options.is_in_domain(temp)) + { + // the long name is not a valid option + if (argv[i][2] == _dT(charT,'\0')) + { + // there was nothing after the -- on the command line + escape = true; + continue; + } + else + { + // there was something after the command line but it + // wasn't a valid option + throw cmd_line_parse_error(EINVALID_OPTION,temp); + } + } + + + option_t* o = static_cast(options[temp]); + + // check the number of arguments after this option and make sure + // it is correct + if (argc + extra_argument <= o->number_of_arguments() + i) + { + // there are too few arguments + throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); + } + if (extra_argument && first_argument.size() == 0 ) + { + // if there would be exactly the right number of arguments if + // the first_argument wasn't empty + if (argc == o->number_of_arguments() + i) + throw cmd_line_parse_error(ETOO_FEW_ARGS,temp,o->number_of_arguments()); + else + { + // in this case we just ignore the trailing = and parse everything + // the same. + extra_argument = 0; + } + } + // you can't force an option that doesn't have any arguments to take + // one by using the --option=arg syntax + if (extra_argument == 1 && o->number_of_arguments() == 0) + { + throw cmd_line_parse_error(ETOO_MANY_ARGS,temp); + } + + + + + + + // at this point we know that the option is ok and we should + // populate its options object + if (o->number_of_arguments() > 0) + { + + string_type* stemp = new string_type[o->number_of_arguments()]; + unsigned long j = 0; + + // add the argument after the = sign if one is present + if (extra_argument) + { + stemp[0] = first_argument; + ++j; + } + + for (; j < o->number_of_arguments(); ++j) + { + stemp[j] = argv[i+j+1-extra_argument]; + } + o->options.add(o->options.size(),stemp); + } + o->num_present += 1; + + + // adjust the value of i to account for the arguments to + // this option + i += o->number_of_arguments() - extra_argument; + } + // -------------------------------------------------------------------- + else + { + // we are looking at the start of a list of a single char options + + // make sure there is something in this string other than - + if (argv[i][1] == _dT(charT,'\0')) + { + throw cmd_line_parse_error(); + } + + string_type temp = &argv[i][1]; + const typename string_type::size_type num = temp.size(); + for (unsigned long k = 0; k < num; ++k) + { + string_type name; + // Doing this instead of name = temp[k] seems to avoid a bug in g++ (Ubuntu/Linaro 4.5.2-8ubuntu4) 4.5.2 + // which results in name[0] having the wrong value. + name.resize(1); + name[0] = temp[k]; + + + // make sure this name is defined + if (!options.is_in_domain(name)) + { + // the name is not a valid option + throw cmd_line_parse_error(EINVALID_OPTION,name); + } + + option_t* o = static_cast(options[name]); + + // if there are chars immediately following this option + int delta = 0; + if (num != k+1) + { + delta = 1; + } + + // check the number of arguments after this option and make sure + // it is correct + if (argc + delta <= o->number_of_arguments() + i) + { + // there are too few arguments + std::ostringstream sout; + throw cmd_line_parse_error(ETOO_FEW_ARGS,name,o->number_of_arguments()); + } + + + o->num_present += 1; + + // at this point we know that the option is ok and we should + // populate its options object + if (o->number_of_arguments() > 0) + { + string_type* stemp = new string_type[o->number_of_arguments()]; + if (delta == 1) + { + temp = &argv[i][2+k]; + k = (unsigned long)num; // this ensures that the argument to this + // option isn't going to be treated as a + // list of options + + stemp[0] = temp; + } + for (unsigned long j = 0; j < o->number_of_arguments()-delta; ++j) + { + stemp[j+delta] = argv[i+j+1]; + } + o->options.add(o->options.size(),stemp); + + // adjust the value of i to account for the arguments to + // this option + i += o->number_of_arguments()-delta; + } + } // for (unsigned long k = 0; k < num; ++k) + } + // -------------------------------------------------------------------- + + } + else + { + // this is just a normal argument + string_type temp = argv[i]; + this->argv.add(this->argv.size(),temp); + } + + } + have_parsed_line = true; + + } + catch (...) + { + have_parsed_line = false; + + // clear all the option objects + options.reset(); + while (options.move_next()) + { + static_cast(options.element().value())->clear(); + } + options.reset(); + + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + bool cmd_line_parser_kernel_1:: + parsed_line( + ) const + { + return have_parsed_line; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + bool cmd_line_parser_kernel_1:: + option_is_defined ( + const string_type& name + ) const + { + return options.is_in_domain(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + set_group_name ( + const string_type& group_name_ + ) + { + group_name = group_name_; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments + ) + { + option_t* temp = new option_t; + try + { + temp->name_ = name; + temp->group_name_ = group_name; + temp->description_ = description; + temp->number_of_arguments_ = number_of_arguments; + void* t = temp; + string_type n(name); + options.add(n,t); + }catch (...) { delete temp; throw;} + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + const cmd_line_parser_option& cmd_line_parser_kernel_1:: + option ( + const string_type& name + ) const + { + return *static_cast*>(options[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + unsigned long cmd_line_parser_kernel_1:: + number_of_arguments( + ) const + { + return argv.size(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + const std::basic_string& cmd_line_parser_kernel_1:: + operator[] ( + unsigned long index + ) const + { + return argv[index]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename charT, + typename map, + typename sequence, + typename sequence2 + > + void cmd_line_parser_kernel_1:: + swap ( + cmd_line_parser_kernel_1& item + ) + { + options.swap(item.options); + argv.swap(item.argv); + exchange(have_parsed_line,item.have_parsed_line); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_1_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h b/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..ac9036132c578a2fae2de5b790660198864b37b5 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_kernel_abstract.h @@ -0,0 +1,673 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ +#ifdef DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include +#include "../interfaces/enumerable.h" +#include "../interfaces/cmd_line_parser_option.h" +#include +#include + +namespace dlib +{ + + template < + typename charT + > + class cmd_line_parser : public enumerable > + { + /*! + REQUIREMENTS ON charT + Must be an integral type suitable for storing characters. (e.g. char + or wchar_t) + + INITIAL VALUE + - parsed_line() == false + - option_is_defined(x) == false, for all values of x + - get_group_name() == "" + + ENUMERATION ORDER + The enumerator will enumerate over all the options defined in *this + in alphabetical order according to the name of the option. + + POINTERS AND REFERENCES TO INTERNAL DATA + parsed_line(), option_is_defined(), option(), number_of_arguments(), + operator[](), and swap() functions do not invalidate pointers or + references to internal data. All other functions have no such guarantee. + + + WHAT THIS OBJECT REPRESENTS + This object represents a command line parser. + The command lines must match the following BNF. + + command_line ::= { | } [ -- {} ] + program_name ::= + arg ::= any that does not start with - + option_arg ::= + option_name ::= + long_option_name ::= { | - } + options ::= - {} {} | + -- [=] { } + char ::= any character other than - or = + word ::= any string from argv where argv is the second + parameter to main() + sword ::= any suffix of a string from argv where argv is the + second parameter to main() + bword ::= This is an empty string which denotes the beginning of a + . + + + Options with arguments: + An option with N arguments will consider the next N swords to be + its arguments. + + so for example, if we have an option o that expects 2 arguments + then the following are a few legal examples: + + program -o arg1 arg2 general_argument + program -oarg1 arg2 general_argument + + arg1 and arg2 are associated with the option o and general_argument + is not. + + Arguments not associated with an option: + An argument that is not associated with an option is considered a + general command line argument and is indexed by operator[] defined + by the cmd_line_parser object. Additionally, if the string + "--" appears in the command line all by itself then all words + following it are considered to be general command line arguments. + + + Consider the following two examples involving a command line and + a cmd_line_parser object called parser. + + Example 1: + command line: program general_arg1 -o arg1 arg2 general_arg2 + Then the following is true (assuming the o option is defined + and takes 2 arguments). + + parser[0] == "general_arg1" + parser[1] == "general_arg2" + parser.number_of_arguments() == 2 + parser.option("o").argument(0) == "arg1" + parser.option("o").argument(1) == "arg2" + parser.option("o").count() == 1 + + Example 2: + command line: program general_arg1 -- -o arg1 arg2 general_arg2 + Then the following is true (the -- causes everything following + it to be treated as a general argument). + + parser[0] == "general_arg1" + parser[1] == "-o" + parser[2] == "arg1" + parser[3] == "arg2" + parser[4] == "general_arg2" + parser.number_of_arguments() == 5 + parser.option("o").count() == 0 + !*/ + + public: + + typedef charT char_type; + typedef std::basic_string string_type; + typedef cmd_line_parser_option option_type; + + // exception class + class cmd_line_parse_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if there is an error detected in a + command line while it is being parsed. You can consult this + object's type and item members to determine the nature of the + error. (note that the type member is inherited from dlib::error). + + INTERPRETING THIS EXCEPTION + - if (type == EINVALID_OPTION) then + - There was an undefined option on the command line + - item == The invalid option that was on the command line + - if (type == ETOO_FEW_ARGS) then + - An option was given on the command line but it was not + supplied with the required number of arguments. + - item == The name of this option. + - num == The number of arguments expected by this option. + - if (type == ETOO_MANY_ARGS) then + - An option was given on the command line such as --option=arg + but this option doesn't take any arguments. + - item == The name of this option. + !*/ + public: + const std::basic_string item; + const unsigned long num; + }; + + // -------------------------- + + cmd_line_parser ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~cmd_line_parser ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable + until clear() is called and succeeds + !*/ + + void parse ( + int argc, + const charT** argv + ); + /*! + requires + - argv == an array of strings that was obtained from the second argument + of the function main(). + (i.e. argv[0] should be the token, argv[1] should be + an or token, etc.) + - argc == the number of strings in argv + ensures + - parses the command line given by argc and argv + - #parsed_line() == true + - #at_start() == true + throws + - std::bad_alloc + if this exception is thrown then #*this is unusable until clear() + is called successfully + - cmd_line_parse_error + This exception is thrown if there is an error parsing the command line. + If this exception is thrown then #parsed_line() == false and all + options will have their count() set to 0 but otherwise there will + be no effect (i.e. all registered options will remain registered). + !*/ + + void parse ( + int argc, + charT** argv + ); + /*! + This just calls this->parse(argc,argv) and performs the necessary const_cast + on argv. + !*/ + + bool parsed_line( + ) const; + /*! + ensures + - returns true if parse() has been called successfully + - returns false otherwise + !*/ + + bool option_is_defined ( + const string_type& name + ) const; + /*! + ensures + - returns true if the option has been added to the parser object + by calling add_option(name). + - returns false otherwise + !*/ + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + /*! + requires + - parsed_line() == false + - option_is_defined(name) == false + - name does not contain any ' ', '\t', '\n', or '=' characters + - name[0] != '-' + - name.size() > 0 + ensures + - #option_is_defined(name) == true + - #at_start() == true + - #option(name).count() == 0 + - #option(name).description() == description + - #option(name).number_of_arguments() == number_of_arguments + - #option(name).group_name() == get_group_name() + throws + - std::bad_alloc + if this exception is thrown then the add_option() function has no + effect + !*/ + + const option_type& option ( + const string_type& name + ) const; + /*! + requires + - option_is_defined(name) == true + ensures + - returns the option specified by name + !*/ + + unsigned long number_of_arguments( + ) const; + /*! + requires + - parsed_line() == true + ensures + - returns the number of arguments present in the command line. + This count does not include options or their arguments. Only + arguments unrelated to any option are counted. + !*/ + + const string_type& operator[] ( + unsigned long N + ) const; + /*! + requires + - parsed_line() == true + - N < number_of_arguments() + ensures + - returns the Nth command line argument + !*/ + + void swap ( + cmd_line_parser& item + ); + /*! + ensures + - swaps *this and item + !*/ + + void print_options ( + std::basic_ostream& out + ) const; + /*! + ensures + - prints all the command line options to out. + - #at_start() == true + throws + - any exception. + if an exception is thrown then #at_start() == true but otherwise + it will have no effect on the state of #*this. + !*/ + + void print_options ( + ) const; + /*! + ensures + - prints all the command line options to cout. + - #at_start() == true + throws + - any exception. + if an exception is thrown then #at_start() == true but otherwise + it will have no effect on the state of #*this. + !*/ + + string_type get_group_name ( + ) const; + /*! + ensures + - returns the current group name. This is the group new options will be + added into when added via add_option(). + - The group name of an option is used by print_options(). In particular, + it groups all options with the same group name together and displays them + under a title containing the text of the group name. This allows you to + group similar options together in the output of print_options(). + - A group name of "" (i.e. the empty string) means that no group name is + set. + !*/ + + void set_group_name ( + const string_type& group_name + ); + /*! + ensures + - #get_group_name() == group_name + !*/ + + // ------------------------------------------------------------- + // Input Validation Tools + // ------------------------------------------------------------- + + class cmd_line_check_error : public dlib::error + { + /*! + This is the exception thrown by the check_*() routines if they find a + command line error. The interpretation of the member variables is defined + below in each check_*() routine. + !*/ + + public: + const string_type opt; + const string_type opt2; + const string_type arg; + const std::vector required_opts; + }; + + template < + typename T + > + void check_option_arg_type ( + const string_type& option_name + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - T is not a pointer type + ensures + - all the arguments for the given option are convertible + by string_cast() to an object of type T. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + typename T + > + void check_option_arg_range ( + const string_type& option_name, + const T& first, + const T& last + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - first <= last + - T is not a pointer type + ensures + - all the arguments for the given option are convertible + by string_cast() to an object of type T and the resulting value is + in the range first to last inclusive. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + typename T, + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const T (&arg_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + - T is not a pointer type + ensures + - for each argument to the given option: + - this argument is convertible by string_cast() to an object of + type T and the resulting value is equal to some element in the + arg_set array. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + size_t length + > + void check_option_arg_range ( + const string_type& option_name, + const char_type* (&arg_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name) == true + ensures + - for each argument to the given option: + - there is a string in the arg_set array that is equal to this argument. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINVALID_OPTION_ARG + - opt == option_name + - arg == the text of the offending argument + !*/ + + template < + size_t length + > + void check_one_time_options ( + const char_type* (&option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(option_set[i]) == true + ensures + - all the options in the option_set array occur at most once on the + command line. + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMULTIPLE_OCCURANCES + - opt == the option that occurred more than once on the command line. + !*/ + + void check_incompatible_options ( + const string_type& option_name1, + const string_type& option_name2 + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(option_name1) == true + - option_is_defined(option_name2) == true + ensures + - option(option_name1).count() == 0 || option(option_name2).count() == 0 + (i.e. at most, only one of the options is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINCOMPATIBLE_OPTIONS + - opt == option_name1 + - opt2 == option_name2 + !*/ + + template < + size_t length + > + void check_incompatible_options ( + const char_type* (&option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(option_set[i]) == true + ensures + - At most only one of the options in the array option_set has a count() + greater than 0. (i.e. at most, only one of the options is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EINCOMPATIBLE_OPTIONS + - opt == One of the incompatible options found. + - opt2 == The next incompatible option found. + !*/ + + void check_sub_option ( + const string_type& parent_option, + const string_type& sub_option + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(parent_option) == true + - option_is_defined(sub_option) == true + ensures + - if (option(parent_option).count() == 0) then + - option(sub_option).count() == 0 + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == sub_option. + - required_opts == a vector that contains only parent_option. + !*/ + + template < + size_t length + > + void check_sub_options ( + const char_type* (&parent_option_set)[length], + const string_type& sub_option + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(sub_option) == true + - for all valid i: + - option_is_defined(parent_option_set[i] == true + ensures + - if (option(sub_option).count() > 0) then + - At least one of the options in the array parent_option_set has a count() + greater than 0. (i.e. at least one of the options in parent_option_set + is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option that is present. + - required_opts == a vector containing everything from parent_option_set. + !*/ + + template < + size_t length + > + void check_sub_options ( + const string_type& parent_option, + const char_type* (&sub_option_set)[length] + ) const; + /*! + requires + - parsed_line() == true + - option_is_defined(parent_option) == true + - for all valid i: + - option_is_defined(sub_option_set[i]) == true + ensures + - if (option(parent_option).count() == 0) then + - for all valid i: + - option(sub_option_set[i]).count() == 0 + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option_set that is present. + - required_opts == a vector that contains only parent_option. + !*/ + + template < + size_t parent_length, + size_t sub_length + > + void check_sub_options ( + const char_type* (&parent_option_set)[parent_length], + const char_type* (&sub_option_set)[sub_length] + ) const; + /*! + requires + - parsed_line() == true + - for all valid i: + - option_is_defined(parent_option_set[i] == true + - for all valid j: + - option_is_defined(sub_option_set[j]) == true + ensures + - for all valid j: + - if (option(sub_option_set[j]).count() > 0) then + - At least one of the options in the array parent_option_set has a count() + greater than 0. (i.e. at least one of the options in parent_option_set + is currently present) + throws + - std::bad_alloc + - cmd_line_check_error + This exception is thrown if the ensures clause could not be satisfied. + The exception's members will be set as follows: + - type == EMISSING_REQUIRED_OPTION + - opt == the first option from the sub_option_set that is present. + - required_opts == a vector containing everything from parent_option_set. + !*/ + + + private: + + // restricted functions + cmd_line_parser(cmd_line_parser&); // copy constructor + cmd_line_parser& operator=(cmd_line_parser&); // assignment operator + + }; + +// ----------------------------------------------------------------------------------------- + + typedef cmd_line_parser command_line_parser; + typedef cmd_line_parser wcommand_line_parser; + +// ----------------------------------------------------------------------------------------- + + template < + typename charT + > + inline void swap ( + cmd_line_parser& a, + cmd_line_parser& b + ) { a.swap(b); } + /*! + provides a global swap function + !*/ + +// ----------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_ABSTRACT_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h b/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..e80543018e17bc3700b3a824ad23954200be674d --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_kernel_c.h @@ -0,0 +1,203 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_KERNEl_C_ +#define DLIB_CMD_LINE_PARSER_KERNEl_C_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include +#include "../interfaces/cmd_line_parser_option.h" +#include "../string.h" + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_kernel_c : public clp_base + { + public: + + typedef typename clp_base::char_type char_type; + typedef typename clp_base::string_type string_type; + typedef typename clp_base::option_type option_type; + + void add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments = 0 + ); + + const option_type& option ( + const string_type& name + ) const; + + unsigned long number_of_arguments( + ) const; + + const option_type& element ( + ) const; + + option_type& element ( + ); + + const string_type& operator[] ( + unsigned long N + ) const; + + }; + + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_kernel_c& a, + cmd_line_parser_kernel_c& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::string_type& cmd_line_parser_kernel_c:: + operator[] ( + unsigned long N + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true && N < number_of_arguments(), + "\tvoid cmd_line_parser::operator[](unsigned long N)" + << "\n\tYou must specify a valid index N and the parser must have run already." + << "\n\tthis: " << this + << "\n\tN: " << N + << "\n\tparsed_line(): " << this->parsed_line() + << "\n\tnumber_of_arguments(): " << number_of_arguments() + ); + + return clp_base::operator[](N); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + void cmd_line_parser_kernel_c:: + add_option ( + const string_type& name, + const string_type& description, + unsigned long number_of_arguments + ) + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == false && + name.size() > 0 && + this->option_is_defined(name) == false && + name.find_first_of(_dT(char_type," \t\n=")) == string_type::npos && + name[0] != '-', + "\tvoid cmd_line_parser::add_option(const string_type&,const string_type&,unsigned long)" + << "\n\tsee the requires clause of add_option()" + << "\n\tthis: " << this + << "\n\tname.size(): " << static_cast(name.size()) + << "\n\tname: \"" << narrow(name) << "\"" + << "\n\tparsed_line(): " << (this->parsed_line()? "true" : "false") + << "\n\tis_option_defined(\"" << narrow(name) << "\"): " << (this->option_is_defined(name)? "true" : "false") + ); + + clp_base::add_option(name,description,number_of_arguments); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::option_type& cmd_line_parser_kernel_c:: + option ( + const string_type& name + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->option_is_defined(name) == true, + "\toption cmd_line_parser::option(const string_type&)" + << "\n\tto get an option it must be defined by a call to add_option()" + << "\n\tthis: " << this + << "\n\tname: \"" << narrow(name) << "\"" + ); + + return clp_base::option(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + unsigned long cmd_line_parser_kernel_c:: + number_of_arguments( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( this->parsed_line() == true , + "\tunsigned long cmd_line_parser::number_of_arguments()" + << "\n\tyou must parse the command line before you can find out how many arguments it has" + << "\n\tthis: " << this + ); + + return clp_base::number_of_arguments(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + const typename clp_base::option_type& cmd_line_parser_kernel_c:: + element ( + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tconst cmd_line_parser_option& cmd_line_parser::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return clp_base::element(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + typename clp_base::option_type& cmd_line_parser_kernel_c:: + element ( + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(this->current_element_valid() == true, + "\tcmd_line_parser_option& cmd_line_parser::element()" + << "\n\tyou can't access the current element if it doesn't exist" + << "\n\tthis: " << this + ); + + // call the real function + return clp_base::element(); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_KERNEl_C_ + diff --git a/dlib/cmd_line_parser/cmd_line_parser_print_1.h b/dlib/cmd_line_parser/cmd_line_parser_print_1.h new file mode 100644 index 0000000000000000000000000000000000000000..3f52c842f23fe57655e93b33b807fe7a68b2d821 --- /dev/null +++ b/dlib/cmd_line_parser/cmd_line_parser_print_1.h @@ -0,0 +1,205 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CMD_LINE_PARSER_PRINt_1_ +#define DLIB_CMD_LINE_PARSER_PRINt_1_ + +#include "cmd_line_parser_kernel_abstract.h" +#include "../algs.h" +#include "../string.h" +#include +#include +#include +#include +#include + +namespace dlib +{ + + template < + typename clp_base + > + class cmd_line_parser_print_1 : public clp_base + { + + public: + + void print_options ( + std::basic_ostream& out + ) const; + + void print_options ( + ) const + { + print_options(std::cout); + } + + }; + + template < + typename clp_base + > + inline void swap ( + cmd_line_parser_print_1& a, + cmd_line_parser_print_1& b + ) { a.swap(b); } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename clp_base + > + void cmd_line_parser_print_1:: + print_options ( + std::basic_ostream& out + ) const + { + typedef typename clp_base::char_type ct; + typedef std::basic_string string; + typedef typename string::size_type size_type; + + typedef std::basic_ostringstream ostringstream; + + try + { + + + size_type max_len = 0; + this->reset(); + + // this loop here is just the bottom loop but without the print statements. + // I'm doing this to figure out what len should be. + while (this->move_next()) + { + size_type len = 0; + len += 3; + if (this->element().name().size() > 1) + { + ++len; + } + len += this->element().name().size(); + + if (this->element().number_of_arguments() == 1) + { + len += 6; + } + else + { + for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) + { + len += 7; + if (i+1 > 9) + ++len; + } + } + + len += 3; + if (len < 33) + max_len = std::max(max_len,len); + } + + + // Make a separate ostringstream for each option group. We are going to write + // the output for each group to a separate ostringstream so that we can keep + // them grouped together in the final output. + std::map > groups; + this->reset(); + while(this->move_next()) + { + if (!groups[this->element().group_name()]) + groups[this->element().group_name()].reset(new ostringstream); + } + + + + + this->reset(); + + while (this->move_next()) + { + ostringstream& sout = *groups[this->element().group_name()]; + + size_type len = 0; + sout << _dT(ct,"\n -"); + len += 3; + if (this->element().name().size() > 1) + { + sout << _dT(ct,"-"); + ++len; + } + sout << this->element().name(); + len += this->element().name().size(); + + if (this->element().number_of_arguments() == 1) + { + sout << _dT(ct," "); + len += 6; + } + else + { + for (unsigned long i = 0; i < this->element().number_of_arguments(); ++i) + { + sout << _dT(ct," "); + len += 7; + if (i+1 > 9) + ++len; + } + } + + sout << _dT(ct," "); + len += 3; + + while (len < max_len) + { + ++len; + sout << _dT(ct," "); + } + + const unsigned long ml = static_cast(max_len); + // now print the description but make it wrap around nicely if it + // is to long to fit on one line. + if (len <= max_len) + sout << wrap_string(this->element().description(),0,ml); + else + sout << _dT(ct,"\n") << wrap_string(this->element().description(),ml,ml); + } + + // Only print out a generic Options: group name if there is an unnamed option + // present. + if (groups.count(string()) == 1) + out << _dT(ct,"Options:"); + + // Now print everything out + typename std::map >::iterator i; + for (i = groups.begin(); i != groups.end(); ++i) + { + // print the group name if we have one + if (i->first.size() != 0) + { + if (i != groups.begin()) + out << _dT(ct,"\n\n"); + out << i->first << _dT(ct,":"); + } + + // print the options in the group + out << i->second->str(); + } + out << _dT(ct,"\n\n"); + this->reset(); + } + catch (...) + { + this->reset(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CMD_LINE_PARSER_PRINt_1_ + diff --git a/dlib/cmd_line_parser/get_option.h b/dlib/cmd_line_parser/get_option.h new file mode 100644 index 0000000000000000000000000000000000000000..2c8d1644f7c2b58905b69ae563a6d8169dcae5a1 --- /dev/null +++ b/dlib/cmd_line_parser/get_option.h @@ -0,0 +1,181 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_GET_OPTiON_Hh_ +#define DLIB_GET_OPTiON_Hh_ + +#include "get_option_abstract.h" +#include "../string.h" +#include "../is_kind.h" + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class option_parse_error : public error + { + public: + option_parse_error(const std::string& option_string, const std::string& str): + error(EOPTION_PARSE,"Error parsing argument for option '" + option_string + "', offending string is '" + str + "'.") {} + }; + +// ---------------------------------------------------------------------------------------- + + template + T impl_config_reader_get_option ( + const config_reader_type& cr, + const std::string& option_name, + const std::string& full_option_name, + T default_value + ) + { + std::string::size_type pos = option_name.find_first_of("."); + if (pos == std::string::npos) + { + if (cr.is_key_defined(option_name)) + { + try{ return string_cast(cr[option_name]); } + catch (string_cast_error&) { throw option_parse_error(full_option_name, cr[option_name]); } + } + } + else + { + std::string block_name = option_name.substr(0,pos); + if (cr.is_block_defined(block_name)) + { + return impl_config_reader_get_option(cr.block(block_name), + option_name.substr(pos+1), + full_option_name, + default_value); + } + } + + return default_value; + } + +// ---------------------------------------------------------------------------------------- + + template + typename enable_if,T>::type get_option ( + const cr_type& cr, + const std::string& option_name, + T default_value + ) + { + return impl_config_reader_get_option(cr, option_name, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const parser_type& parser, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + { + try + { + default_value = string_cast(parser.option(option_name).argument()); + } + catch (string_cast_error&) + { + throw option_parse_error(option_name, parser.option(option_name).argument()); + } + } + return default_value; + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const parser_type& parser, + const cr_type& cr, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + return get_option(parser, option_name, default_value); + else + return get_option(cr, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- + + template + typename disable_if,T>::type get_option ( + const cr_type& cr, + const parser_type& parser, + const std::string& option_name, + T default_value + ) + { + // make sure requires clause is not broken + DLIB_ASSERT( parser.option_is_defined(option_name) == true && + parser.option(option_name).number_of_arguments() == 1, + "\t T get_option()" + << "\n\t option_name: " << option_name + << "\n\t parser.option_is_defined(option_name): " << parser.option_is_defined(option_name) + << "\n\t parser.option(option_name).number_of_arguments(): " << parser.option(option_name).number_of_arguments() + ); + + if (parser.option(option_name)) + return get_option(parser, option_name, default_value); + else + return get_option(cr, option_name, default_value); + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template + inline std::string get_option ( + const T& cr, + const std::string& option_name, + const char* default_value + ) + { + return get_option(cr, option_name, std::string(default_value)); + } + +// ---------------------------------------------------------------------------------------- + + template + inline std::string get_option ( + const T& parser, + const U& cr, + const std::string& option_name, + const char* default_value + ) + { + return get_option(parser, cr, option_name, std::string(default_value)); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GET_OPTiON_Hh_ + diff --git a/dlib/cmd_line_parser/get_option_abstract.h b/dlib/cmd_line_parser/get_option_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..90dc16721791921d92ca1c2c76fc24bd1cdbc2b5 --- /dev/null +++ b/dlib/cmd_line_parser/get_option_abstract.h @@ -0,0 +1,146 @@ +// Copyright (C) 2012 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_GET_OPTiON_ABSTRACT_Hh_ +#ifdef DLIB_GET_OPTiON_ABSTRACT_Hh_ + +#inclue + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class option_parse_error : public error + { + /*! + WHAT THIS OBJECT REPRESENTS + This is the exception thrown by the get_option() functions. It is + thrown when the option string given by a command line parser or + config reader can't be converted into the type T. + !*/ + }; + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_type, + typename T + > + T get_option ( + const config_reader_type& cr, + const std::string& option_name, + T default_value + ); + /*! + requires + - T is a type which can be read from an input stream + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - option_name is used to index into the given config_reader. + - if (cr contains an entry corresponding to option_name) then + - converts the string value in cr corresponding to option_name into + an object of type T and returns it. + - else + - returns default_value + - The scheme for indexing into cr based on option_name is best + understood by looking at a few examples: + - an option name of "name" corresponds to cr["name"] + - an option name of "block1.name" corresponds to cr.block("block1")["name"] + - an option name of "block1.block2.name" corresponds to cr.block("block1").block("block2")["name"] + throws + - option_parse_error + This exception is thrown if we attempt but fail to convert the string value + in cr into an object of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename T + > + T get_option ( + const command_line_parser_type& parser, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - converts parser.option(option_name).argument() into an object + of type T and returns it. That is, the string argument to this + command line option is converted into a T and returned. + - else + - returns default_value + throws + - option_parse_error + This exception is thrown if we attempt but fail to convert the string + argument into an object of type T. + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename config_reader_type, + typename T + > + T get_option ( + const command_line_parser_type& parser, + const config_reader_type& cr, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - returns get_option(parser, option_name, default_value) + - else + - returns get_option(cr, option_name, default_value) + !*/ + +// ---------------------------------------------------------------------------------------- + + template < + typename command_line_parser_type, + typename config_reader_type, + typename T + > + T get_option ( + const config_reader_type& cr, + const command_line_parser_type& parser, + const std::string& option_name, + T default_value + ); + /*! + requires + - parser.option_is_defined(option_name) == true + - parser.option(option_name).number_of_arguments() == 1 + - T is a type which can be read from an input stream + - command_line_parser_type == an implementation of cmd_line_parser/cmd_line_parser_kernel_abstract.h + - config_reader_type == an implementation of config_reader/config_reader_kernel_abstract.h + ensures + - if (parser.option(option_name)) then + - returns get_option(parser, option_name, default_value) + - else + - returns get_option(cr, option_name, default_value) + !*/ + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_GET_OPTiON_ABSTRACT_Hh_ + + diff --git a/dlib/compress_stream.h b/dlib/compress_stream.h new file mode 100644 index 0000000000000000000000000000000000000000..8ccc1d52faf49b311a24b7b6740d3ec48e553771 --- /dev/null +++ b/dlib/compress_stream.h @@ -0,0 +1,133 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAm_ +#define DLIB_COMPRESS_STREAm_ + +#include "compress_stream/compress_stream_kernel_1.h" +#include "compress_stream/compress_stream_kernel_2.h" +#include "compress_stream/compress_stream_kernel_3.h" + +#include "conditioning_class.h" +#include "entropy_encoder.h" +#include "entropy_decoder.h" + +#include "entropy_encoder_model.h" +#include "entropy_decoder_model.h" +#include "lz77_buffer.h" +#include "sliding_buffer.h" +#include "lzp_buffer.h" +#include "crc32.h" + + +namespace dlib +{ + + class compress_stream + { + compress_stream() {} + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_1b fce1; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_1b fcd1; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2b fce2; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2b fcd2; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_3b fce3; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_3b fcd3; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4a fce4a; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4a fcd4a; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_4b fce4b; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_4b fcd4b; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5a fce5a; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5a fcd5a; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5b fce5b; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5b fcd5b; + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_5c fce5c; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_5c fcd5c; + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_6a fce6; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_6a fcd6; + + + typedef entropy_encoder_model<257,entropy_encoder::kernel_2a>::kernel_2d fce2d; + typedef entropy_decoder_model<257,entropy_decoder::kernel_2a>::kernel_2d fcd2d; + + typedef sliding_buffer::kernel_1a sliding_buffer1; + typedef lz77_buffer::kernel_2a lz77_buffer2a; + + + typedef lzp_buffer::kernel_1a lzp_buf_1; + typedef lzp_buffer::kernel_2a lzp_buf_2; + + + typedef entropy_encoder_model<513,entropy_encoder::kernel_2a>::kernel_1b fce_length; + typedef entropy_decoder_model<513,entropy_decoder::kernel_2a>::kernel_1b fcd_length; + + typedef entropy_encoder_model<65534,entropy_encoder::kernel_2a>::kernel_1b fce_length_2; + typedef entropy_decoder_model<65534,entropy_decoder::kernel_2a>::kernel_1b fcd_length_2; + + + typedef entropy_encoder_model<32257,entropy_encoder::kernel_2a>::kernel_1b fce_index; + typedef entropy_decoder_model<32257,entropy_decoder::kernel_2a>::kernel_1b fcd_index; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef compress_stream_kernel_1 + kernel_1a; + + // kernel_1b + typedef compress_stream_kernel_1 + kernel_1b; + + // kernel_1c + typedef compress_stream_kernel_1 + kernel_1c; + + // kernel_1da + typedef compress_stream_kernel_1 + kernel_1da; + + // kernel_1ea + typedef compress_stream_kernel_1 + kernel_1ea; + + // kernel_1db + typedef compress_stream_kernel_1 + kernel_1db; + + // kernel_1eb + typedef compress_stream_kernel_1 + kernel_1eb; + + // kernel_1ec + typedef compress_stream_kernel_1 + kernel_1ec; + + + + + // kernel_2a + typedef compress_stream_kernel_2 + kernel_2a; + + + + + // kernel_3a + typedef compress_stream_kernel_3 + kernel_3a; + // kernel_3b + typedef compress_stream_kernel_3 + kernel_3b; + + + }; +} + +#endif // DLIB_COMPRESS_STREAm_ + diff --git a/dlib/compress_stream/compress_stream_kernel_1.h b/dlib/compress_stream/compress_stream_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..1a75ec6ced988ec9108379d47d29617e3930b511 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_1.h @@ -0,0 +1,252 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_1_ +#define DLIB_COMPRESS_STREAM_KERNEl_1_ + +#include "../algs.h" +#include +#include +#include +#include "compress_stream_kernel_abstract.h" + +namespace dlib +{ + + template < + typename fce, + typename fcd, + typename crc32 + > + class compress_stream_kernel_1 + { + /*! + REQUIREMENTS ON fce + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON fcd + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + !*/ + + const static unsigned long eof_symbol = 256; + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_1 ( + ) + {} + + ~compress_stream_kernel_1 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + private: + + // restricted functions + compress_stream_kernel_1(compress_stream_kernel_1&); // copy constructor + compress_stream_kernel_1& operator=(compress_stream_kernel_1&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename crc32 + > + void compress_stream_kernel_1:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + std::streambuf::int_type temp; + + std::streambuf& in = *in_.rdbuf(); + + typename fce::entropy_encoder_type coder; + coder.set_stream(out_); + + fce model(coder); + + crc32 crc; + + unsigned long count = 0; + + while (true) + { + // write out a known value every 20000 symbols + if (count == 20000) + { + count = 0; + coder.encode(1500,1501,8000); + } + ++count; + + // get the next character + temp = in.sbumpc(); + + // if we have hit EOF then encode the marker symbol + if (temp != EOF) + { + // encode the symbol + model.encode(static_cast(temp)); + crc.add(static_cast(temp)); + continue; + } + else + { + model.encode(eof_symbol); + + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + model.encode(byte1); + model.encode(byte2); + model.encode(byte3); + model.encode(byte4); + + break; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename crc32 + > + void compress_stream_kernel_1:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + + std::streambuf& out = *out_.rdbuf(); + + typename fcd::entropy_decoder_type coder; + coder.set_stream(in_); + + fcd model(coder); + + unsigned long symbol; + unsigned long count = 0; + + crc32 crc; + + // decode until we hit the marker symbol + while (true) + { + // make sure this is the value we expect + if (count == 20000) + { + if (coder.get_target(8000) != 1500) + { + throw decompression_error("Error detected in compressed data stream."); + } + count = 0; + coder.decode(1500,1501); + } + ++count; + + // decode the next symbol + model.decode(symbol); + if (symbol != eof_symbol) + { + crc.add(static_cast(symbol)); + // write this symbol to out + if (out.sputc(static_cast(symbol)) != static_cast(symbol)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_1::decompress"); + } + continue; + } + else + { + // we read eof from the encoded data. now we just have to check the checksum and we are done. + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + model.decode(symbol); byte1 = static_cast(symbol); + model.decode(symbol); byte2 = static_cast(symbol); + model.decode(symbol); byte3 = static_cast(symbol); + model.decode(symbol); byte4 = static_cast(symbol); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + break; + } + } // while (true) + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_1_ + diff --git a/dlib/compress_stream/compress_stream_kernel_2.h b/dlib/compress_stream/compress_stream_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..e46b23fad874758b2553ac1db1a1006f8ec8b064 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_2.h @@ -0,0 +1,431 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_2_ +#define DLIB_COMPRESS_STREAM_KERNEl_2_ + +#include "../algs.h" +#include +#include +#include "compress_stream_kernel_abstract.h" + +namespace dlib +{ + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + class compress_stream_kernel_2 + { + /*! + REQUIREMENTS ON fce + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON fcd + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 257. + fce and fcd share the same kernel number. + + REQUIREMENTS ON lz77_buffer + is an implementation of lz77_buffer/lz77_buffer_kernel_abstract.h + + REQUIREMENTS ON sliding_buffer + is an implementation of sliding_buffer/sliding_buffer_kernel_abstract.h + is instantiated with T = unsigned char + + REQUIREMENTS ON fce_length + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 513. This will be used to encode the length of lz77 matches. + fce_length and fcd share the same kernel number. + + REQUIREMENTS ON fcd_length + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 513. This will be used to decode the length of lz77 matches. + fce_length and fcd share the same kernel number. + + REQUIREMENTS ON fce_index + is an implementation of entropy_encoder_model/entropy_encoder_model_kernel_abstract.h + the alphabet_size of fce must be 32257. This will be used to encode the index of lz77 matches. + fce_index and fcd share the same kernel number. + + REQUIREMENTS ON fcd_index + is an implementation of entropy_decoder_model/entropy_decoder_model_kernel_abstract.h + the alphabet_size of fcd must be 32257. This will be used to decode the index of lz77 matches. + fce_index and fcd share the same kernel number. + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + !*/ + + const static unsigned long eof_symbol = 256; + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_2 ( + ) + {} + + ~compress_stream_kernel_2 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + private: + + // restricted functions + compress_stream_kernel_2(compress_stream_kernel_2&); // copy constructor + compress_stream_kernel_2& operator=(compress_stream_kernel_2&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + void compress_stream_kernel_2:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + std::streambuf::int_type temp; + + std::streambuf& in = *in_.rdbuf(); + + typename fce::entropy_encoder_type coder; + coder.set_stream(out_); + + fce model(coder); + fce_length model_length(coder); + fce_index model_index(coder); + + const unsigned long LOOKAHEAD_LIMIT = 512; + lz77_buffer buffer(15,LOOKAHEAD_LIMIT); + + crc32 crc; + + + unsigned long count = 0; + + unsigned long lz77_count = 1; // number of times we used lz77 to encode + unsigned long ppm_count = 1; // number of times we used ppm to encode + + + while (true) + { + // write out a known value every 20000 symbols + if (count == 20000) + { + count = 0; + coder.encode(150,151,400); + } + ++count; + + // try to fill the lookahead buffer + if (buffer.get_lookahead_buffer_size() < buffer.get_lookahead_buffer_limit()) + { + temp = in.sbumpc(); + while (temp != EOF) + { + crc.add(static_cast(temp)); + buffer.add(static_cast(temp)); + if (buffer.get_lookahead_buffer_size() == buffer.get_lookahead_buffer_limit()) + break; + temp = in.sbumpc(); + } + } + + // compute the sum of ppm_count and lz77_count but make sure + // it is less than 65536 + unsigned long sum = ppm_count + lz77_count; + if (sum >= 65536) + { + ppm_count >>= 1; + lz77_count >>= 1; + ppm_count |= 1; + lz77_count |= 1; + sum = ppm_count+lz77_count; + } + + // if there are still more symbols in the lookahead buffer to encode + if (buffer.get_lookahead_buffer_size() > 0) + { + unsigned long match_index, match_length; + buffer.find_match(match_index,match_length,6); + if (match_length != 0) + { + + // signal the decoder that we are using lz77 + coder.encode(0,lz77_count,sum); + ++lz77_count; + + // encode the index and length pair + model_index.encode(match_index); + model_length.encode(match_length); + + } + else + { + + // signal the decoder that we are using ppm + coder.encode(lz77_count,sum,sum); + ++ppm_count; + + // encode the symbol using the ppm model + model.encode(buffer.lookahead_buffer(0)); + buffer.shift_buffers(1); + } + } + else + { + // signal the decoder that we are using ppm + coder.encode(lz77_count,sum,sum); + + + model.encode(eof_symbol); + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + model.encode(byte1); + model.encode(byte2); + model.encode(byte3); + model.encode(byte4); + + break; + } + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + + template < + typename fce, + typename fcd, + typename lz77_buffer, + typename sliding_buffer, + typename fce_length, + typename fcd_length, + typename fce_index, + typename fcd_index, + typename crc32 + > + void compress_stream_kernel_2:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + + std::streambuf& out = *out_.rdbuf(); + + typename fcd::entropy_decoder_type coder; + coder.set_stream(in_); + + fcd model(coder); + fcd_length model_length(coder); + fcd_index model_index(coder); + + unsigned long symbol; + unsigned long count = 0; + + sliding_buffer buffer; + buffer.set_size(15); + + // Initialize the buffer to all zeros. There is no algorithmic reason to + // do this. But doing so avoids a warning from valgrind so that is why + // I'm doing this. + for (unsigned long i = 0; i < buffer.size(); ++i) + buffer[i] = 0; + + crc32 crc; + + unsigned long lz77_count = 1; // number of times we used lz77 to encode + unsigned long ppm_count = 1; // number of times we used ppm to encode + bool next_block_lz77; + + + // decode until we hit the marker symbol + while (true) + { + // make sure this is the value we expect + if (count == 20000) + { + if (coder.get_target(400) != 150) + { + throw decompression_error("Error detected in compressed data stream."); + } + count = 0; + coder.decode(150,151); + } + ++count; + + + // compute the sum of ppm_count and lz77_count but make sure + // it is less than 65536 + unsigned long sum = ppm_count + lz77_count; + if (sum >= 65536) + { + ppm_count >>= 1; + lz77_count >>= 1; + ppm_count |= 1; + lz77_count |= 1; + sum = ppm_count+lz77_count; + } + + // check if we are decoding a lz77 or ppm block + if (coder.get_target(sum) < lz77_count) + { + coder.decode(0,lz77_count); + next_block_lz77 = true; + ++lz77_count; + } + else + { + coder.decode(lz77_count,sum); + next_block_lz77 = false; + ++ppm_count; + } + + + if (next_block_lz77) + { + + unsigned long match_length, match_index; + // decode the match index + model_index.decode(match_index); + + // decode the match length + model_length.decode(match_length); + + + match_index += match_length; + buffer.rotate_left(match_length); + for (unsigned long i = 0; i < match_length; ++i) + { + unsigned char ch = buffer[match_index-i]; + buffer[match_length-i-1] = ch; + + crc.add(ch); + // write this ch to out + if (out.sputc(static_cast(ch)) != static_cast(ch)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress"); + } + } + + } + else + { + + // decode the next symbol + model.decode(symbol); + if (symbol != eof_symbol) + { + buffer.rotate_left(1); + buffer[0] = static_cast(symbol); + + + crc.add(static_cast(symbol)); + // write this symbol to out + if (out.sputc(static_cast(symbol)) != static_cast(symbol)) + { + throw std::ios::failure("error occurred in compress_stream_kernel_2::decompress"); + } + } + else + { + // this was the eof marker symbol so we are done. now check the checksum + + // now get the checksum and make sure it matches + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + model.decode(symbol); byte1 = static_cast(symbol); + model.decode(symbol); byte2 = static_cast(symbol); + model.decode(symbol); byte3 = static_cast(symbol); + model.decode(symbol); byte4 = static_cast(symbol); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + break; + } + } + + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_2_ + diff --git a/dlib/compress_stream/compress_stream_kernel_3.h b/dlib/compress_stream/compress_stream_kernel_3.h new file mode 100644 index 0000000000000000000000000000000000000000..ed4eee290d08a28036ca9a6a49ef1bd6454d7eb5 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_3.h @@ -0,0 +1,381 @@ +// Copyright (C) 2005 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_COMPRESS_STREAM_KERNEl_3_ +#define DLIB_COMPRESS_STREAM_KERNEl_3_ + +#include "../algs.h" +#include "compress_stream_kernel_abstract.h" +#include "../assert.h" + +namespace dlib +{ + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + class compress_stream_kernel_3 + { + /*! + REQUIREMENTS ON lzp_buf + is an implementation of lzp_buffer/lzp_buffer_kernel_abstract.h + + REQUIREMENTS ON buffer_size + 10 < buffer_size < 32 + + REQUIREMENTS ON crc32 + is an implementation of crc32/crc32_kernel_abstract.h + + + INITIAL VALUE + this object has no state + + CONVENTION + this object has no state + + + This implementation uses the lzp_buffer and writes out matches + in a byte aligned format. + + !*/ + + + public: + + class decompression_error : public dlib::error + { + public: + decompression_error( + const char* i + ) : + dlib::error(std::string(i)) + {} + + decompression_error( + const std::string& i + ) : + dlib::error(i) + {} + }; + + + compress_stream_kernel_3 ( + ) + { + COMPILE_TIME_ASSERT(10 < buffer_size && buffer_size < 32); + } + + ~compress_stream_kernel_3 ( + ) + {} + + void compress ( + std::istream& in, + std::ostream& out + ) const; + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + + + + private: + + inline void write ( + unsigned char symbol + ) const + { + if (out->sputn(reinterpret_cast(&symbol),1)==0) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + } + + inline void decode ( + unsigned char& symbol, + unsigned char& flag + ) const + { + if (count == 0) + { + if (((size_t)in->sgetn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw decompression_error("Error detected in compressed data stream."); + count = 8; + } + --count; + symbol = buffer[8-count]; + flag = buffer[0] >> 7; + buffer[0] <<= 1; + } + + inline void encode ( + unsigned char symbol, + unsigned char flag + ) const + /*! + requires + - 0 <= flag <= 1 + ensures + - writes symbol with the given one bit flag + !*/ + { + // add this symbol and flag to the buffer + ++count; + buffer[0] <<= 1; + buffer[count] = symbol; + buffer[0] |= flag; + + if (count == 8) + { + if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + count = 0; + buffer[0] = 0; + } + } + + void clear ( + ) const + /*! + ensures + - resets the buffers + !*/ + { + count = 0; + } + + void flush ( + ) const + /*! + ensures + - flushes any data in the buffers to out + !*/ + { + if (count != 0) + { + buffer[0] <<= (8-count); + if (((size_t)out->sputn(reinterpret_cast(buffer),sizeof(buffer)))!=sizeof(buffer)) + throw std::ios_base::failure("error writing to output stream in compress_stream_kernel_3"); + } + } + + mutable unsigned int count; + // count tells us how many bytes are buffered in buffer and how many flag + // bit are currently in buffer[0] + mutable unsigned char buffer[9]; + // buffer[0] holds the flag bits to be writen. + // the rest of the buffer holds the bytes to be writen. + + mutable std::streambuf* in; + mutable std::streambuf* out; + + // restricted functions + compress_stream_kernel_3(compress_stream_kernel_3&); // copy constructor + compress_stream_kernel_3& operator=(compress_stream_kernel_3&); // assignment operator + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + void compress_stream_kernel_3:: + compress ( + std::istream& in_, + std::ostream& out_ + ) const + { + in = in_.rdbuf(); + out = out_.rdbuf(); + clear(); + + crc32 crc; + + lzp_buf buffer(buffer_size); + + std::streambuf::int_type temp = in->sbumpc(); + unsigned long index; + unsigned char symbol; + unsigned char length; + + while (temp != EOF) + { + symbol = static_cast(temp); + if (buffer.predict_match(index)) + { + if (buffer[index] == symbol) + { + // this is a match so we must find out how long it is + length = 1; + + buffer.add(symbol); + crc.add(symbol); + + temp = in->sbumpc(); + while (length < 255) + { + if (temp == EOF) + { + break; + } + else if (static_cast(length) >= index) + { + break; + } + else if (static_cast(temp) == buffer[index]) + { + ++length; + buffer.add(static_cast(temp)); + crc.add(static_cast(temp)); + temp = in->sbumpc(); + } + else + { + break; + } + } + + encode(length,1); + } + else + { + // this is also not a match + encode(symbol,0); + buffer.add(symbol); + crc.add(symbol); + + // get the next symbol + temp = in->sbumpc(); + } + } + else + { + // there wasn't a match so just write this symbol + encode(symbol,0); + buffer.add(symbol); + crc.add(symbol); + + // get the next symbol + temp = in->sbumpc(); + } + } + + // use a match of zero length to indicate EOF + encode(0,1); + + // now write the checksum + unsigned long checksum = crc.get_checksum(); + unsigned char byte1 = static_cast((checksum>>24)&0xFF); + unsigned char byte2 = static_cast((checksum>>16)&0xFF); + unsigned char byte3 = static_cast((checksum>>8)&0xFF); + unsigned char byte4 = static_cast((checksum)&0xFF); + + encode(byte1,0); + encode(byte2,0); + encode(byte3,0); + encode(byte4,0); + + flush(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename lzp_buf, + typename crc32, + unsigned long buffer_size + > + void compress_stream_kernel_3:: + decompress ( + std::istream& in_, + std::ostream& out_ + ) const + { + in = in_.rdbuf(); + out = out_.rdbuf(); + clear(); + + crc32 crc; + + lzp_buf buffer(buffer_size); + + + unsigned long index = 0; + unsigned char symbol; + unsigned char length; + unsigned char flag; + + decode(symbol,flag); + while (flag == 0 || symbol != 0) + { + buffer.predict_match(index); + + if (flag == 1) + { + length = symbol; + do + { + --length; + symbol = buffer[index]; + write(symbol); + buffer.add(symbol); + crc.add(symbol); + } while (length != 0); + } + else + { + // this is just a literal + write(symbol); + buffer.add(symbol); + crc.add(symbol); + } + decode(symbol,flag); + } + + + // now get the checksum and make sure it matches + unsigned char byte1; + unsigned char byte2; + unsigned char byte3; + unsigned char byte4; + + decode(byte1,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte2,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte3,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + decode(byte4,flag); + if (flag != 0) + throw decompression_error("Error detected in compressed data stream."); + + unsigned long checksum = byte1; + checksum <<= 8; + checksum |= byte2; + checksum <<= 8; + checksum |= byte3; + checksum <<= 8; + checksum |= byte4; + + if (checksum != crc.get_checksum()) + throw decompression_error("Error detected in compressed data stream."); + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_3_ + diff --git a/dlib/compress_stream/compress_stream_kernel_abstract.h b/dlib/compress_stream/compress_stream_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..48f46d9e11afac0120c4c56274905f032a199865 --- /dev/null +++ b/dlib/compress_stream/compress_stream_kernel_abstract.h @@ -0,0 +1,94 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ +#ifdef DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ + +#include "../algs.h" +#include + +namespace dlib +{ + + class compress_stream + { + /*! + INITIAL VALUE + This object does not have any state associated with it. + + WHAT THIS OBJECT REPRESENTS + This object consists of the two functions compress and decompress. + These functions allow you to compress and decompress data. + !*/ + + public: + + class decompression_error : public dlib::error {}; + + compress_stream ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + virtual ~compress_stream ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + + void compress ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads all data from in (until EOF is reached) and compresses it + and writes it to out + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + + void decompress ( + std::istream& in, + std::ostream& out + ) const; + /*! + ensures + - reads data from in, decompresses it and writes it to out. note that + it stops reading data from in when it encounters the end of the + compressed data, not when it encounters EOF. + throws + - std::ios_base::failure + if there was a problem writing to out then this exception will + be thrown. + - decompression_error + if an error was detected in the compressed data that prevented + it from being correctly decompressed then this exception is + thrown. + - any other exception + this exception may be thrown if there is any other problem + !*/ + + + private: + + // restricted functions + compress_stream(compress_stream&); // copy constructor + compress_stream& operator=(compress_stream&); // assignment operator + + }; + +} + +#endif // DLIB_COMPRESS_STREAM_KERNEl_ABSTRACT_ + diff --git a/dlib/conditioning_class.h b/dlib/conditioning_class.h new file mode 100644 index 0000000000000000000000000000000000000000..409b9871602ef2942dce9e7c88a5c81db789e9a6 --- /dev/null +++ b/dlib/conditioning_class.h @@ -0,0 +1,80 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASs_ +#define DLIB_CONDITIONING_CLASs_ + +#include "conditioning_class/conditioning_class_kernel_1.h" +#include "conditioning_class/conditioning_class_kernel_2.h" +#include "conditioning_class/conditioning_class_kernel_3.h" +#include "conditioning_class/conditioning_class_kernel_4.h" +#include "conditioning_class/conditioning_class_kernel_c.h" + + +#include "memory_manager.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class + { + conditioning_class() {} + + typedef memory_manager::kernel_2b mm; + + public: + + //----------- kernels --------------- + + // kernel_1a + typedef conditioning_class_kernel_1 + kernel_1a; + typedef conditioning_class_kernel_c + kernel_1a_c; + + // kernel_2a + typedef conditioning_class_kernel_2 + kernel_2a; + typedef conditioning_class_kernel_c + kernel_2a_c; + + // kernel_3a + typedef conditioning_class_kernel_3 + kernel_3a; + typedef conditioning_class_kernel_c + kernel_3a_c; + + + // -------- kernel_4 --------- + + // kernel_4a + typedef conditioning_class_kernel_4 + kernel_4a; + typedef conditioning_class_kernel_c + kernel_4a_c; + + // kernel_4b + typedef conditioning_class_kernel_4 + kernel_4b; + typedef conditioning_class_kernel_c + kernel_4b_c; + + // kernel_4c + typedef conditioning_class_kernel_4 + kernel_4c; + typedef conditioning_class_kernel_c + kernel_4c_c; + + // kernel_4d + typedef conditioning_class_kernel_4 + kernel_4d; + typedef conditioning_class_kernel_c + kernel_4d_c; + + }; +} + +#endif // DLIB_CONDITIONING_CLASS_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_1.h b/dlib/conditioning_class/conditioning_class_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..d26d80244aa76acbe7f25bf6d52c8eaf53fdc047 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_1.h @@ -0,0 +1,333 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_1_ +#define DLIB_CONDITIONING_CLASS_KERNEl_1_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_1 + { + /*! + INITIAL VALUE + total == 1 + counts == pointer to an array of alphabet_size unsigned shorts + for all i except i == alphabet_size-1: counts[i] == 0 + counts[alphabet_size-1] == 1 + + CONVENTION + counts == pointer to an array of alphabet_size unsigned shorts + get_total() == total + get_count(symbol) == counts[symbol] + + LOW_COUNT(symbol) == sum of counts[0] though counts[symbol-1] + or 0 if symbol == 0 + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_1; + }; + + conditioning_class_kernel_1 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_1 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + + private: + + // restricted functions + conditioning_class_kernel_1(conditioning_class_kernel_1&); // copy constructor + conditioning_class_kernel_1& operator=(conditioning_class_kernel_1&); // assignment operator + + // data members + unsigned short total; + unsigned short* counts; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_1:: + conditioning_class_kernel_1 ( + global_state_type& global_state_ + ) : + total(1), + counts(new unsigned short[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size-1; + while (start != end) + { + *start = 0; + ++start; + } + *start = 1; + + // update memory usage + global_state.memory_usage += sizeof(unsigned short)*alphabet_size + + sizeof(conditioning_class_kernel_1); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_1:: + ~conditioning_class_kernel_1 ( + ) + { + delete [] counts; + // update memory usage + global_state.memory_usage -= sizeof(unsigned short)*alphabet_size + + sizeof(conditioning_class_kernel_1); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_1:: + clear( + ) + { + total = 1; + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size-1; + while (start != end) + { + *start = 0; + ++start; + } + *start = 1; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_1::global_state_type& conditioning_class_kernel_1:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_1:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we are going over a total of 65535 then scale down all counts by 2 + if (static_cast(total)+static_cast(amount) >= 65536) + { + total = 0; + unsigned short* start = counts; + unsigned short* end = counts+alphabet_size; + while (start != end) + { + *start >>= 1; + total += *start; + ++start; + } + // make sure it is at least one + if (counts[alphabet_size-1]==0) + { + ++total; + counts[alphabet_size-1] = 1; + } + } + counts[symbol] += amount; + total += amount; + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_count ( + unsigned long symbol + ) const + { + return counts[symbol]; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_1:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (counts[symbol] == 0) + return 0; + + total_count = total; + + const unsigned short* start = counts; + const unsigned short* end = counts+symbol; + unsigned short high_count_temp = *start; + while (start != end) + { + ++start; + high_count_temp += *start; + } + low_count = high_count_temp - *start; + high_count = high_count_temp; + return *start; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_1:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long high_count_temp = *counts; + const unsigned short* start = counts; + while (target >= high_count_temp) + { + ++start; + high_count_temp += *start; + } + + low_count = high_count_temp - *start; + high_count = high_count_temp; + symbol = static_cast(start-counts); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_2.h b/dlib/conditioning_class/conditioning_class_kernel_2.h new file mode 100644 index 0000000000000000000000000000000000000000..c9b38c8e3b95e6b98393ad511a92bb7a62423bc4 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_2.h @@ -0,0 +1,500 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_ +#define DLIB_CONDITIONING_CLASS_KERNEl_2_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_2 + { + /*! + INITIAL VALUE + total == 1 + symbols == pointer to array of alphabet_size data structs + for all i except i == alphabet_size-1: symbols[i].count == 0 + symbols[i].left_count == 0 + + symbols[alphabet_size-1].count == 1 + symbols[alpahbet_size-1].left_count == 0 + + CONVENTION + symbols == pointer to array of alphabet_size data structs + get_total() == total + get_count(symbol) == symbols[symbol].count + + symbols is organized as a tree with symbols[0] as the root. + + the left subchild of symbols[i] is symbols[i*2+1] and + the right subchild is symbols[i*2+2]. + the partent of symbols[i] == symbols[(i-1)/2] + + symbols[i].left_count == the sum of the counts of all the + symbols to the left of symbols[i] + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_2; + }; + + conditioning_class_kernel_2 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_2 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + inline unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + private: + + // restricted functions + conditioning_class_kernel_2(conditioning_class_kernel_2&); // copy constructor + conditioning_class_kernel_2& operator=(conditioning_class_kernel_2&); // assignment operator + + // data members + unsigned short total; + struct data + { + unsigned short count; + unsigned short left_count; + }; + + data* symbols; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_2:: + conditioning_class_kernel_2 ( + global_state_type& global_state_ + ) : + total(1), + symbols(new data[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + data* start = symbols; + data* end = symbols + alphabet_size-1; + + while (start != end) + { + start->count = 0; + start->left_count = 0; + ++start; + } + + start->count = 1; + start->left_count = 0; + + + // update the left_counts for the symbol alphabet_size-1 + unsigned short temp; + unsigned long symbol = alphabet_size-1; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + if (temp) + ++symbols[symbol].left_count; + } + + global_state.memory_usage += sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_2); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_2:: + ~conditioning_class_kernel_2 ( + ) + { + delete [] symbols; + global_state.memory_usage -= sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_2); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_2:: + clear( + ) + { + data* start = symbols; + data* end = symbols + alphabet_size-1; + + total = 1; + + while (start != end) + { + start->count = 0; + start->left_count = 0; + ++start; + } + + start->count = 1; + start->left_count = 0; + + // update the left_counts + unsigned short temp; + unsigned long symbol = alphabet_size-1; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + symbols[symbol].left_count += temp; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_2::global_state_type& conditioning_class_kernel_2:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_2:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we need to renormalize then do so + if (static_cast(total)+static_cast(amount) >= 65536) + { + unsigned long s; + unsigned short temp; + for (unsigned short i = 0; i < alphabet_size-1; ++i) + { + s = i; + + // divide the count for this symbol by 2 + symbols[i].count >>= 1; + + symbols[i].left_count = 0; + + // bubble this change up though the tree + while (s != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(s&0x1); + + // set s to its parent + s = (s-1)>>1; + + // note that all left subchidren are odd and also that + // if s was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[s].left_count += symbols[i].count; + } + } + + // update symbols alphabet_size-1 + { + s = alphabet_size-1; + + // divide alphabet_size-1 symbol by 2 if it's > 1 + if (symbols[alphabet_size-1].count > 1) + symbols[alphabet_size-1].count >>= 1; + + // bubble this change up though the tree + while (s != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(s&0x1); + + // set s to its parent + s = (s-1)>>1; + + // note that all left subchidren are odd and also that + // if s was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[s].left_count += symbols[alphabet_size-1].count; + } + } + + + + + + + // calculate the new total + total = 0; + unsigned long m = 0; + while (m < alphabet_size) + { + total += symbols[m].count + symbols[m].left_count; + m = (m<<1) + 2; + } + + } + + + + + // increment the count for the specified symbol + symbols[symbol].count += amount;; + total += amount; + + + unsigned short temp; + while (symbol != 0) + { + // temp will be 1 if symbol is odd, 0 if it is even + temp = static_cast(symbol&0x1); + + // set symbol to its parent + symbol = (symbol-1)>>1; + + // note that all left subchidren are odd and also that + // if symbol was a left subchild then we want to increment + // its parents left_count + if (temp) + symbols[symbol].left_count += amount; + } + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_count ( + unsigned long symbol + ) const + { + return symbols[symbol].count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_2:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (symbols[symbol].count == 0) + return 0; + + unsigned long current = symbol; + total_count = total; + unsigned long high_count_temp = 0; + bool came_from_right = true; + while (true) + { + if (came_from_right) + { + high_count_temp += symbols[current].count + symbols[current].left_count; + } + + // note that if current is even then it is a right child + came_from_right = !(current&0x1); + + if (current == 0) + break; + + // set current to its parent + current = (current-1)>>1 ; + } + + + low_count = high_count_temp - symbols[symbol].count; + high_count = high_count_temp; + + return symbols[symbol].count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_2:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long current = 0; + unsigned long low_count_temp = 0; + + while (true) + { + if (static_cast(target) < symbols[current].left_count) + { + // we should go left + current = (current<<1) + 1; + } + else + { + target -= symbols[current].left_count; + low_count_temp += symbols[current].left_count; + if (static_cast(target) < symbols[current].count) + { + // we have found our target + symbol = current; + high_count = low_count_temp + symbols[current].count; + low_count = low_count_temp; + break; + } + else + { + // go right + target -= symbols[current].count; + low_count_temp += symbols[current].count; + current = (current<<1) + 2; + } + } + + } + + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_3.h b/dlib/conditioning_class/conditioning_class_kernel_3.h new file mode 100644 index 0000000000000000000000000000000000000000..b6de485558b47bb65f7b32e66bbeabfbbd5fae1d --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_3.h @@ -0,0 +1,438 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_3_ +#define DLIB_CONDITIONING_CLASS_KERNEl_3_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class_kernel_3 + { + /*! + INITIAL VALUE + total == 1 + counts == pointer to an array of alphabet_size data structs + for all i except i == 0: counts[i].count == 0 + counts[0].count == 1 + counts[0].symbol == alphabet_size-1 + for all i except i == alphabet_size-1: counts[i].present == false + counts[alphabet_size-1].present == true + + CONVENTION + counts == pointer to an array of alphabet_size data structs + get_total() == total + get_count(symbol) == counts[x].count where + counts[x].symbol == symbol + + + LOW_COUNT(symbol) == sum of counts[0].count though counts[x-1].count + where counts[x].symbol == symbol + if (counts[0].symbol == symbol) LOW_COUNT(symbol)==0 + + + if (counts[i].count == 0) then + counts[i].symbol == undefined value + + if (symbol has a nonzero count) then + counts[symbol].present == true + + get_memory_usage() == global_state.memory_usage + !*/ + + public: + + class global_state_type + { + public: + global_state_type () : memory_usage(0) {} + private: + unsigned long memory_usage; + + friend class conditioning_class_kernel_3; + }; + + conditioning_class_kernel_3 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_3 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + private: + + // restricted functions + conditioning_class_kernel_3(conditioning_class_kernel_3&); // copy constructor + conditioning_class_kernel_3& operator=(conditioning_class_kernel_3&); // assignment operator + + struct data + { + unsigned short count; + unsigned short symbol; + bool present; + }; + + // data members + unsigned short total; + data* counts; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_3:: + conditioning_class_kernel_3 ( + global_state_type& global_state_ + ) : + total(1), + counts(new data[alphabet_size]), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + data* start = counts; + data* end = counts+alphabet_size; + start->count = 1; + start->symbol = alphabet_size-1; + start->present = false; + ++start; + while (start != end) + { + start->count = 0; + start->present = false; + ++start; + } + counts[alphabet_size-1].present = true; + + // update memory usage + global_state.memory_usage += sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_3); + + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + conditioning_class_kernel_3:: + ~conditioning_class_kernel_3 ( + ) + { + delete [] counts; + // update memory usage + global_state.memory_usage -= sizeof(data)*alphabet_size + + sizeof(conditioning_class_kernel_3); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_3:: + clear( + ) + { + total = 1; + data* start = counts; + data* end = counts+alphabet_size; + start->count = 1; + start->symbol = alphabet_size-1; + start->present = false; + ++start; + while (start != end) + { + start->count = 0; + start->present = false; + ++start; + } + counts[alphabet_size-1].present = true; + + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + typename conditioning_class_kernel_3::global_state_type& conditioning_class_kernel_3:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + bool conditioning_class_kernel_3:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // if we are going over a total of 65535 then scale down all counts by 2 + if (static_cast(total)+static_cast(amount) >= 65536) + { + total = 0; + data* start = counts; + data* end = counts+alphabet_size; + + while (start != end) + { + if (start->count == 1) + { + if (start->symbol == alphabet_size-1) + { + // this symbol must never be zero so we will leave its count at 1 + ++total; + } + else + { + start->count = 0; + counts[start->symbol].present = false; + } + } + else + { + start->count >>= 1; + total += start->count; + } + + ++start; + } + } + + + data* start = counts; + data* swap_spot = counts; + + if (counts[symbol].present) + { + while (true) + { + if (start->symbol == symbol && start->count!=0) + { + unsigned short temp = start->count + amount; + + start->symbol = swap_spot->symbol; + start->count = swap_spot->count; + + swap_spot->symbol = static_cast(symbol); + swap_spot->count = temp; + break; + } + + if ( (start->count) < (swap_spot->count)) + { + swap_spot = start; + } + + + ++start; + } + } + else + { + counts[symbol].present = true; + while (true) + { + if (start->count == 0) + { + start->symbol = swap_spot->symbol; + start->count = swap_spot->count; + + swap_spot->symbol = static_cast(symbol); + swap_spot->count = amount; + break; + } + + if ((start->count) < (swap_spot->count)) + { + swap_spot = start; + } + + ++start; + } + } + + total += amount; + + return true; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_count ( + unsigned long symbol + ) const + { + if (counts[symbol].present == false) + return 0; + + data* start = counts; + while (start->symbol != symbol) + { + ++start; + } + return start->count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + unsigned long conditioning_class_kernel_3:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (counts[symbol].present == false) + return 0; + + total_count = total; + unsigned long low_count_temp = 0; + data* start = counts; + while (start->symbol != symbol) + { + low_count_temp += start->count; + ++start; + } + + low_count = low_count_temp; + high_count = low_count_temp + start->count; + return start->count; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size + > + void conditioning_class_kernel_3:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + unsigned long high_count_temp = counts->count; + const data* start = counts; + while (target >= high_count_temp) + { + ++start; + high_count_temp += start->count; + } + + low_count = high_count_temp - start->count; + high_count = high_count_temp; + symbol = static_cast(start->symbol); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_3_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_4.h b/dlib/conditioning_class/conditioning_class_kernel_4.h new file mode 100644 index 0000000000000000000000000000000000000000..cb48ac196267d23e2341102d0cac173e8ba5a877 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_4.h @@ -0,0 +1,533 @@ +// Copyright (C) 2004 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_4_ +#define DLIB_CONDITIONING_CLASS_KERNEl_4_ + +#include "conditioning_class_kernel_abstract.h" +#include "../assert.h" +#include "../algs.h" + +namespace dlib +{ + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + class conditioning_class_kernel_4 + { + /*! + REQUIREMENTS ON pool_size + pool_size > 0 + this will be the number of nodes contained in our memory pool + + REQUIREMENTS ON mem_manager + mem_manager is an implementation of memory_manager/memory_manager_kernel_abstract.h + + INITIAL VALUE + total == 1 + escapes == 1 + next == 0 + + CONVENTION + get_total() == total + get_count(alphabet_size-1) == escapes + + if (next != 0) then + next == pointer to the start of a linked list and the linked list + is terminated by a node with a next pointer of 0. + + get_count(symbol) == node::count for the node where node::symbol==symbol + or 0 if no such node currently exists. + + if (there is a node for the symbol) then + LOW_COUNT(symbol) == the sum of all node's counts in the linked list + up to but not including the node for the symbol. + + get_memory_usage() == global_state.memory_usage + !*/ + + + struct node + { + unsigned short symbol; + unsigned short count; + node* next; + }; + + public: + + class global_state_type + { + public: + global_state_type ( + ) : + memory_usage(pool_size*sizeof(node)+sizeof(global_state_type)) + {} + private: + unsigned long memory_usage; + + typename mem_manager::template rebind::other pool; + + friend class conditioning_class_kernel_4; + }; + + conditioning_class_kernel_4 ( + global_state_type& global_state_ + ); + + ~conditioning_class_kernel_4 ( + ); + + void clear( + ); + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + inline unsigned long get_total ( + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + + unsigned long get_memory_usage ( + ) const; + + global_state_type& get_global_state ( + ); + + static unsigned long get_alphabet_size ( + ); + + + private: + + void half_counts ( + ); + /*! + ensures + - divides all counts by 2 but ensures that escapes is always at least 1 + !*/ + + // restricted functions + conditioning_class_kernel_4(conditioning_class_kernel_4&); // copy constructor + conditioning_class_kernel_4& operator=(conditioning_class_kernel_4&); // assignment operator + + // data members + unsigned short total; + unsigned short escapes; + node* next; + global_state_type& global_state; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + conditioning_class_kernel_4:: + conditioning_class_kernel_4 ( + global_state_type& global_state_ + ) : + total(1), + escapes(1), + next(0), + global_state(global_state_) + { + COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); + + // update memory usage + global_state.memory_usage += sizeof(conditioning_class_kernel_4); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + conditioning_class_kernel_4:: + ~conditioning_class_kernel_4 ( + ) + { + clear(); + // update memory usage + global_state.memory_usage -= sizeof(conditioning_class_kernel_4); + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + clear( + ) + { + total = 1; + escapes = 1; + while (next) + { + node* temp = next; + next = next->next; + global_state.pool.deallocate(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_memory_usage( + ) const + { + return global_state.memory_usage; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + typename conditioning_class_kernel_4::global_state_type& conditioning_class_kernel_4:: + get_global_state( + ) + { + return global_state; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + bool conditioning_class_kernel_4:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + if (symbol == alphabet_size-1) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + escapes += amount; + total += amount; + return true; + } + + + // find the symbol and increment it or add a new node to the list + if (next) + { + node* temp = next; + node* previous = 0; + while (true) + { + if (temp->symbol == static_cast(symbol)) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + // we have found the symbol + total += amount; + temp->count += amount; + + // if this node now has a count greater than its parent node + if (previous && temp->count > previous->count) + { + // swap the nodes so that the nodes will be in semi-sorted order + swap(temp->count,previous->count); + swap(temp->symbol,previous->symbol); + } + return true; + } + else if (temp->next == 0) + { + // we did not find the symbol so try to add it to the list + if (global_state.pool.get_number_of_allocations() < pool_size) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + node* t = global_state.pool.allocate(); + t->next = 0; + t->symbol = static_cast(symbol); + t->count = amount; + temp->next = t; + total += amount; + return true; + } + else + { + // no memory left + return false; + } + } + else if (temp->count == 0) + { + // remove nodes that have a zero count + if (previous) + { + previous->next = temp->next; + node* t = temp; + temp = temp->next; + global_state.pool.deallocate(t); + } + else + { + next = temp->next; + node* t = temp; + temp = temp->next; + global_state.pool.deallocate(t); + } + } + else + { + previous = temp; + temp = temp->next; + } + } // while (true) + } + // if there aren't any nodes in the list yet then do this instead + else + { + if (global_state.pool.get_number_of_allocations() < pool_size) + { + // make sure we won't cause any overflow + if (total >= 65536 - amount ) + half_counts(); + + next = global_state.pool.allocate(); + next->next = 0; + next->symbol = static_cast(symbol); + next->count = amount; + total += amount; + return true; + } + else + { + // no memory left + return false; + } + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_count ( + unsigned long symbol + ) const + { + if (symbol == alphabet_size-1) + { + return escapes; + } + else + { + node* temp = next; + while (temp) + { + if (temp->symbol == symbol) + return temp->count; + temp = temp->next; + } + return 0; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_alphabet_size ( + ) + { + return alphabet_size; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_total ( + ) const + { + return total; + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + unsigned long conditioning_class_kernel_4:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + if (symbol != alphabet_size-1) + { + node* temp = next; + unsigned long low = 0; + while (temp) + { + if (temp->symbol == static_cast(symbol)) + { + high_count = temp->count + low; + low_count = low; + total_count = total; + return temp->count; + } + low += temp->count; + temp = temp->next; + } + return 0; + } + else + { + total_count = total; + high_count = total; + low_count = total-escapes; + return escapes; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + node* temp = next; + unsigned long high = 0; + while (true) + { + if (temp != 0) + { + high += temp->count; + if (target < high) + { + symbol = temp->symbol; + high_count = high; + low_count = high - temp->count; + return; + } + temp = temp->next; + } + else + { + // this must be the escape symbol + symbol = alphabet_size-1; + low_count = total-escapes; + high_count = total; + return; + } + } + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // private member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + + template < + unsigned long alphabet_size, + unsigned long pool_size, + typename mem_manager + > + void conditioning_class_kernel_4:: + half_counts ( + ) + { + total = 0; + if (escapes > 1) + escapes >>= 1; + + //divide all counts by 2 + node* temp = next; + while (temp) + { + temp->count >>= 1; + total += temp->count; + temp = temp->next; + } + total += escapes; + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_4_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_abstract.h b/dlib/conditioning_class/conditioning_class_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..411aea56653a471dbc9226a718aa061e4078f78d --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_abstract.h @@ -0,0 +1,228 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ +#ifdef DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ + +#include "../algs.h" + +namespace dlib +{ + + template < + unsigned long alphabet_size + > + class conditioning_class + { + /*! + REQUIREMENTS ON alphabet_size + 1 < alphabet_size < 65536 + + INITIAL VALUE + get_total() == 1 + get_count(X) == 0 : for all valid values of X except alphabet_size-1 + get_count(alphabet_size-1) == 1 + + WHAT THIS OBJECT REPRESENTS + This object represents a conditioning class used for arithmetic style + compression. It maintains the cumulative counts which are needed + by the entropy_coder and entropy_decoder objects. + + At any moment a conditioning_class object represents a set of + alphabet_size symbols. Each symbol is associated with an integer + called its count. + + All symbols start out with a count of zero except for alphabet_size-1. + This last symbol will always have a count of at least one. It is + intended to be used as an escape into a lower context when coding + and so it must never have a zero probability or the decoder won't + be able to identify the escape symbol. + + NOTATION: + Let MAP(i) be a function which maps integers to symbols. MAP(i) is + one to one and onto. Its domain is 1 to alphabet_size inclusive. + + Let RMAP(s) be the inverse of MAP(i). + ( i.e. RMAP(MAP(i)) == i and MAP(RMAP(s)) == s ) + + Let COUNT(i) give the count for the symbol MAP(i). + ( i.e. COUNT(i) == get_count(MAP(i)) ) + + + Let LOW_COUNT(s) == the sum of COUNT(x) for x == 1 to x == RMAP(s)-1 + (note that the sum of COUNT(x) for x == 1 to x == 0 is 0) + Let HIGH_COUNT(s) == LOW_COUNT(s) + get_count(s) + + + + Basically what this is saying is just that you shoudln't assume you know + what order the symbols are placed in when calculating the cumulative + sums. The specific mapping provided by the MAP() function is unspecified. + + THREAD SAFETY + This object can be used safely in a multithreaded program as long as the + global state is not shared between conditioning classes which run on + different threads. + + GLOBAL_STATE_TYPE + The global_state_type obejct allows instances of the conditioning_class + object to share any kind of global state the implementer desires. + However, the global_state_type object exists primarily to facilitate the + sharing of a memory pool between many instances of a conditioning_class + object. But note that it is not required that there be any kind of + memory pool at all, it is just a possibility. + !*/ + + public: + + class global_state_type + { + global_state_type ( + ); + /*! + ensures + - #*this is properly initialized + throws + - std::bad_alloc + !*/ + + // my contents are implementation specific. + }; + + conditioning_class ( + global_state_type& global_state + ); + /*! + ensures + - #*this is properly initialized + - &#get_global_state() == &global_state + throws + - std::bad_alloc + !*/ + + ~conditioning_class ( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + !*/ + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + /*! + requires + - 0 <= symbol < alphabet_size + - 0 < amount < 32768 + ensures + - if (sufficient memory is available to complete this operation) then + - returns true + - if (get_total()+amount < 65536) then + - #get_count(symbol) == get_count(symbol) + amount + - else + - #get_count(symbol) == get_count(symbol)/2 + amount + - if (get_count(alphabet_size-1) == 1) then + - #get_count(alphabet_size-1) == 1 + - else + - #get_count(alphabet_size-1) == get_count(alphabet_size-1)/2 + - for all X where (X != symbol)&&(X != alpahbet_size-1): + #get_count(X) == get_count(X)/2 + - else + - returns false + !*/ + + unsigned long get_count ( + unsigned long symbol + ) const; + /*! + requires + - 0 <= symbol < alphabet_size + ensures + - returns the count for the specified symbol + !*/ + + unsigned long get_total ( + ) const; + /*! + ensures + - returns the sum of get_count(X) for all valid values of X + (i.e. returns the sum of the counts for all the symbols) + !*/ + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + /*! + requires + - 0 <= symbol < alphabet_size + ensures + - returns get_count(symbol) + - if (get_count(symbol) != 0) then + - #total_count == get_total() + - #low_count == LOW_COUNT(symbol) + - #high_count == HIGH_COUNT(symbol) + - #low_count < #high_count <= #total_count + !*/ + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + /*! + requires + - 0 <= target < get_total() + ensures + - LOW_COUNT(#symbol) <= target < HIGH_COUNT(#symbol) + - #low_count == LOW_COUNT(#symbol) + - #high_count == HIGH_COUNT(#symbol) + - #low_count < #high_count <= get_total() + !*/ + + global_state_type& get_global_state ( + ); + /*! + ensures + - returns a reference to the global state used by *this + !*/ + + unsigned long get_memory_usage ( + ) const; + /*! + ensures + - returns the number of bytes of memory allocated by all conditioning_class + objects that share the global state given by get_global_state() + !*/ + + static unsigned long get_alphabet_size ( + ); + /*! + ensures + - returns alphabet_size + !*/ + + private: + + // restricted functions + conditioning_class(conditioning_class&); // copy constructor + conditioning_class& operator=(conditioning_class&); // assignment operator + + }; + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_ABSTRACT_ + diff --git a/dlib/conditioning_class/conditioning_class_kernel_c.h b/dlib/conditioning_class/conditioning_class_kernel_c.h new file mode 100644 index 0000000000000000000000000000000000000000..964240be862a3bc5b5ed769bd1cf0ce9f4288700 --- /dev/null +++ b/dlib/conditioning_class/conditioning_class_kernel_c.h @@ -0,0 +1,162 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONDITIONING_CLASS_KERNEl_C_ +#define DLIB_CONDITIONING_CLASS_KERNEl_C_ + +#include "conditioning_class_kernel_abstract.h" +#include "../algs.h" +#include "../assert.h" +#include + +namespace dlib +{ + + template < + typename cc_base + > + class conditioning_class_kernel_c : public cc_base + { + const unsigned long alphabet_size; + + public: + + conditioning_class_kernel_c ( + typename cc_base::global_state_type& global_state + ) : cc_base(global_state),alphabet_size(cc_base::get_alphabet_size()) {} + + bool increment_count ( + unsigned long symbol, + unsigned short amount = 1 + ); + + unsigned long get_count ( + unsigned long symbol + ) const; + + unsigned long get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const; + + void get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const; + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + bool conditioning_class_kernel_c:: + increment_count ( + unsigned long symbol, + unsigned short amount + ) + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size && + 0 < amount && amount < 32768, + "\tvoid conditioning_class::increment_count()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1. and" + << "\n\tamount must be in the range 1 to 32767" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tamount: " << amount + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::increment_count(symbol,amount); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + unsigned long conditioning_class_kernel_c:: + get_count ( + unsigned long symbol + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size, + "\tvoid conditioning_class::get_count()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::get_count(symbol); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + unsigned long conditioning_class_kernel_c:: + get_range ( + unsigned long symbol, + unsigned long& low_count, + unsigned long& high_count, + unsigned long& total_count + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT(symbol < alphabet_size, + "\tvoid conditioning_class::get_range()" + << "\n\tthe symbol must be in the range 0 to alphabet_size-1" + << "\n\talphabet_size: " << alphabet_size + << "\n\tsymbol: " << symbol + << "\n\tthis: " << this + ); + + // call the real function + return cc_base::get_range(symbol,low_count,high_count,total_count); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename cc_base + > + void conditioning_class_kernel_c:: + get_symbol ( + unsigned long target, + unsigned long& symbol, + unsigned long& low_count, + unsigned long& high_count + ) const + { + // make sure requires clause is not broken + DLIB_CASSERT( target < this->get_total(), + "\tvoid conditioning_class::get_symbol()" + << "\n\tthe target must be in the range 0 to get_total()-1" + << "\n\tget_total(): " << this->get_total() + << "\n\ttarget: " << target + << "\n\tthis: " << this + ); + + // call the real function + cc_base::get_symbol(target,symbol,low_count,high_count); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONDITIONING_CLASS_KERNEl_C_ + diff --git a/dlib/config.h b/dlib/config.h new file mode 100644 index 0000000000000000000000000000000000000000..f4e7dda4f27680152839c43df56049a789f5d41e --- /dev/null +++ b/dlib/config.h @@ -0,0 +1,31 @@ + + +// If you are compiling dlib as a shared library and installing it somewhere on your system +// then it is important that any programs that use dlib agree on the state of the +// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, +// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or +// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle +// automatically depending on the state of certain other macros, which is not what you want +// when creating a shared library. +//#define ENABLE_ASSERTS // asserts always enabled +//#define DLIB_DISABLE_ASSERTS // asserts always disabled + +//#define DLIB_ISO_CPP_ONLY +//#define DLIB_NO_GUI_SUPPORT +//#define DLIB_ENABLE_STACK_TRACE + +// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, +// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. +// #define DLIB_JPEG_SUPPORT +// #define DLIB_PNG_SUPPORT +// #define DLIB_GIF_SUPPORT +// #define DLIB_USE_FFTW +// #define DLIB_USE_BLAS +// #define DLIB_USE_LAPACK +// #define DLIB_USE_ROCM + + +// Define this so the code in dlib/test_for_odr_violations.h can detect ODR violations +// related to users doing bad things with config.h +#define DLIB_NOT_CONFIGURED + diff --git a/dlib/config.h.in b/dlib/config.h.in new file mode 100644 index 0000000000000000000000000000000000000000..ce324aeb7fa0aaac199e10bd9b916ac2eb520728 --- /dev/null +++ b/dlib/config.h.in @@ -0,0 +1,36 @@ + + +// If you are compiling dlib as a shared library and installing it somewhere on your system +// then it is important that any programs that use dlib agree on the state of the +// DLIB_ASSERT statements (i.e. they are either always on or always off). Therefore, +// uncomment one of the following lines to force all DLIB_ASSERTs to either always on or +// always off. If you don't define one of these two macros then DLIB_ASSERT will toggle +// automatically depending on the state of certain other macros, which is not what you want +// when creating a shared library. +#cmakedefine ENABLE_ASSERTS // asserts always enabled +#cmakedefine DLIB_DISABLE_ASSERTS // asserts always disabled + +#cmakedefine DLIB_ISO_CPP_ONLY +#cmakedefine DLIB_NO_GUI_SUPPORT +#cmakedefine DLIB_ENABLE_STACK_TRACE + +#cmakedefine LAPACK_FORCE_UNDERSCORE +#cmakedefine LAPACK_FORCE_NOUNDERSCORE + +// You should also consider telling dlib to link against libjpeg, libpng, libgif, fftw, CUDA, +// and a BLAS and LAPACK library. To do this you need to uncomment the following #defines. +#cmakedefine DLIB_JPEG_SUPPORT +#cmakedefine DLIB_WEBP_SUPPORT +#cmakedefine DLIB_PNG_SUPPORT +#cmakedefine DLIB_GIF_SUPPORT +#cmakedefine DLIB_USE_FFTW +#cmakedefine DLIB_USE_BLAS +#cmakedefine DLIB_USE_LAPACK +#cmakedefine DLIB_USE_ROCM +#cmakedefine DLIB_USE_MKL_FFT +#cmakedefine DLIB_USE_FFMPEG + +// This variable allows dlib/test_for_odr_violations.h to catch people who mistakenly use +// headers from one version of dlib with a compiled dlib binary from a different dlib version. +#cmakedefine DLIB_CHECK_FOR_VERSION_MISMATCH @DLIB_CHECK_FOR_VERSION_MISMATCH@ + diff --git a/dlib/config_reader.h b/dlib/config_reader.h new file mode 100644 index 0000000000000000000000000000000000000000..d140a310ce2ddb9e7ace1c3a7abaf18d80ffb6e8 --- /dev/null +++ b/dlib/config_reader.h @@ -0,0 +1,39 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READEr_ +#define DLIB_CONFIG_READEr_ + +#include "config_reader/config_reader_kernel_1.h" +#include "map.h" +#include "tokenizer.h" +#include "cmd_line_parser/get_option.h" + +#include "algs.h" +#include "is_kind.h" + + +namespace dlib +{ + + typedef config_reader_kernel_1< + map::kernel_1b, + map::kernel_1b, + tokenizer::kernel_1a + > config_reader; + + template <> struct is_config_reader { const static bool value = true; }; + +#ifndef DLIB_ISO_CPP_ONLY + typedef config_reader_thread_safe_1< + config_reader, + map::kernel_1b + > config_reader_thread_safe; + + template <> struct is_config_reader { const static bool value = true; }; +#endif // DLIB_ISO_CPP_ONLY + + +} + +#endif // DLIB_CONFIG_READEr_ + diff --git a/dlib/config_reader/config_reader_kernel_1.h b/dlib/config_reader/config_reader_kernel_1.h new file mode 100644 index 0000000000000000000000000000000000000000..c0f9e5a7110eb5743183642964d8849bda3d4b53 --- /dev/null +++ b/dlib/config_reader/config_reader_kernel_1.h @@ -0,0 +1,738 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READER_KERNEl_1_ +#define DLIB_CONFIG_READER_KERNEl_1_ + +#include "config_reader_kernel_abstract.h" +#include +#include +#include +#include +#include "../algs.h" +#include "../stl_checked/std_vector_c.h" + +#ifndef DLIB_ISO_CPP_ONLY +#include "config_reader_thread_safe_1.h" +#endif + +namespace dlib +{ + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + class config_reader_kernel_1 + { + + /*! + REQUIREMENTS ON map_string_string + is an implementation of map/map_kernel_abstract.h that maps std::string to std::string + + REQUIREMENTS ON map_string_void + is an implementation of map/map_kernel_abstract.h that maps std::string to void* + + REQUIREMENTS ON tokenizer + is an implementation of tokenizer/tokenizer_kernel_abstract.h + + CONVENTION + key_table.is_in_domain(x) == is_key_defined(x) + block_table.is_in_domain(x) == is_block_defined(x) + + key_table[x] == operator[](x) + block_table[x] == (void*)&block(x) + !*/ + + public: + + // These two typedefs are defined for backwards compatibility with older versions of dlib. + typedef config_reader_kernel_1 kernel_1a; +#ifndef DLIB_ISO_CPP_ONLY + typedef config_reader_thread_safe_1< + config_reader_kernel_1, + map_string_void + > thread_safe_1a; +#endif // DLIB_ISO_CPP_ONLY + + + config_reader_kernel_1(); + + class config_reader_error : public dlib::error + { + friend class config_reader_kernel_1; + config_reader_error( + unsigned long ln, + bool r = false + ) : + dlib::error(ECONFIG_READER), + line_number(ln), + redefinition(r) + { + std::ostringstream sout; + sout << "Error in config_reader while parsing at line number " << line_number << "."; + if (redefinition) + sout << "\nThe identifier on this line has already been defined in this scope."; + const_cast(info) = sout.str(); + } + public: + const unsigned long line_number; + const bool redefinition; + }; + + class file_not_found : public dlib::error + { + friend class config_reader_kernel_1; + file_not_found( + const std::string& file_name_ + ) : + dlib::error(ECONFIG_READER, "Error in config_reader, unable to open file " + file_name_), + file_name(file_name_) + {} + + ~file_not_found() throw() {} + + public: + const std::string file_name; + }; + + class config_reader_access_error : public dlib::error + { + public: + config_reader_access_error( + const std::string& block_name_, + const std::string& key_name_ + ) : + dlib::error(ECONFIG_READER), + block_name(block_name_), + key_name(key_name_) + { + std::ostringstream sout; + sout << "Error in config_reader.\n"; + if (block_name.size() > 0) + sout << " A block with the name '" << block_name << "' was expected but not found."; + else if (key_name.size() > 0) + sout << " A key with the name '" << key_name << "' was expected but not found."; + + const_cast(info) = sout.str(); + } + + ~config_reader_access_error() throw() {} + const std::string block_name; + const std::string key_name; + }; + + config_reader_kernel_1( + const std::string& config_file + ); + + config_reader_kernel_1( + std::istream& in + ); + + virtual ~config_reader_kernel_1( + ); + + void clear ( + ); + + void load_from ( + std::istream& in + ); + + void load_from ( + const std::string& config_file + ); + + bool is_key_defined ( + const std::string& key + ) const; + + bool is_block_defined ( + const std::string& name + ) const; + + typedef config_reader_kernel_1 this_type; + const this_type& block ( + const std::string& name + ) const; + + const std::string& operator[] ( + const std::string& key + ) const; + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + + template < + typename alloc + > + void get_keys ( + std::vector& keys + ) const; + + template < + typename alloc + > + void get_keys ( + std_vector_c& keys + ) const; + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + + template < + typename alloc + > + void get_blocks ( + std::vector& blocks + ) const; + + template < + typename alloc + > + void get_blocks ( + std_vector_c& blocks + ) const; + + private: + + static void parse_config_file ( + config_reader_kernel_1& cr, + tokenizer& tok, + unsigned long& line_number, + const bool top_of_recursion = true + ); + /*! + requires + - line_number == 1 + - cr == *this + - top_of_recursion == true + ensures + - parses the data coming from tok and puts it into cr. + throws + - config_reader_error + !*/ + + map_string_string key_table; + map_string_void block_table; + + // restricted functions + config_reader_kernel_1(config_reader_kernel_1&); + config_reader_kernel_1& operator=(config_reader_kernel_1&); + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + ) + { + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + clear( + ) + { + // free all our blocks + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + key_table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + load_from( + std::istream& in + ) + { + clear(); + + tokenizer tok; + tok.set_stream(in); + tok.set_identifier_token( + tok.lowercase_letters() + tok.uppercase_letters(), + tok.lowercase_letters() + tok.uppercase_letters() + tok.numbers() + "_-." + ); + + unsigned long line_number = 1; + try + { + parse_config_file(*this,tok,line_number); + } + catch (...) + { + clear(); + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + load_from( + const std::string& config_file + ) + { + clear(); + std::ifstream fin(config_file.c_str()); + if (!fin) + throw file_not_found(config_file); + + load_from(fin); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + std::istream& in + ) + { + load_from(in); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + config_reader_kernel_1( + const std::string& config_file + ) + { + load_from(config_file); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + void config_reader_kernel_1:: + parse_config_file( + config_reader_kernel_1& cr, + tokenizer& tok, + unsigned long& line_number, + const bool top_of_recursion + ) + { + int type; + std::string token; + bool in_comment = false; + bool seen_identifier = false; + std::string identifier; + while (true) + { + tok.get_token(type,token); + // ignore white space + if (type == tokenizer::WHITE_SPACE) + continue; + + // basically ignore end of lines + if (type == tokenizer::END_OF_LINE) + { + ++line_number; + in_comment = false; + continue; + } + + // we are in a comment still so ignore this + if (in_comment) + continue; + + // if this is the start of a comment + if (type == tokenizer::CHAR && token[0] == '#') + { + in_comment = true; + continue; + } + + // if this is the case then we have just finished parsing a block so we should + // quit this function + if ( (type == tokenizer::CHAR && token[0] == '}' && !top_of_recursion) || + (type == tokenizer::END_OF_FILE && top_of_recursion) ) + { + break; + } + + if (seen_identifier) + { + seen_identifier = false; + // the next character should be either a '=' or a '{' + if (type != tokenizer::CHAR || (token[0] != '=' && token[0] != '{')) + throw config_reader_error(line_number); + + if (token[0] == '=') + { + // we should parse the value out now + // first discard any white space + if (tok.peek_type() == tokenizer::WHITE_SPACE) + tok.get_token(type,token); + + std::string value; + type = tok.peek_type(); + token = tok.peek_token(); + while (true) + { + if (type == tokenizer::END_OF_FILE || type == tokenizer::END_OF_LINE) + break; + + if (type == tokenizer::CHAR && token[0] == '\\') + { + tok.get_token(type,token); + if (tok.peek_type() == tokenizer::CHAR && + tok.peek_token()[0] == '#') + { + tok.get_token(type,token); + value += '#'; + } + else if (tok.peek_type() == tokenizer::CHAR && + tok.peek_token()[0] == '}') + { + tok.get_token(type,token); + value += '}'; + } + else + { + value += '\\'; + } + } + else if (type == tokenizer::CHAR && + (token[0] == '#' || token[0] == '}')) + { + break; + } + else + { + value += token; + tok.get_token(type,token); + } + type = tok.peek_type(); + token = tok.peek_token(); + } // while(true) + + // strip of any tailing white space from value + std::string::size_type pos = value.find_last_not_of(" \t\r\n"); + if (pos == std::string::npos) + value.clear(); + else + value.erase(pos+1); + + // make sure this key isn't already in the key_table + if (cr.key_table.is_in_domain(identifier)) + throw config_reader_error(line_number,true); + + // add this key/value pair to the key_table + cr.key_table.add(identifier,value); + + } + else // when token[0] == '{' + { + // make sure this identifier isn't already in the block_table + if (cr.block_table.is_in_domain(identifier)) + throw config_reader_error(line_number,true); + + config_reader_kernel_1* new_cr = new config_reader_kernel_1; + void* vtemp = new_cr; + try { cr.block_table.add(identifier,vtemp); } + catch (...) { delete new_cr; throw; } + + // now parse this block + parse_config_file(*new_cr,tok,line_number,false); + } + } + else + { + // the next thing should be an identifier but if it isn't this is an error + if (type != tokenizer::IDENTIFIER) + throw config_reader_error(line_number); + + seen_identifier = true; + identifier = token; + } + } // while (true) + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + config_reader_kernel_1:: + ~config_reader_kernel_1( + ) + { + clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + bool config_reader_kernel_1:: + is_key_defined ( + const std::string& key + ) const + { + return key_table.is_in_domain(key); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + bool config_reader_kernel_1:: + is_block_defined ( + const std::string& name + ) const + { + return block_table.is_in_domain(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename mss, + typename msv, + typename tokenizer + > + const config_reader_kernel_1& config_reader_kernel_1:: + block ( + const std::string& name + ) const + { + if (is_block_defined(name) == false) + { + throw config_reader_access_error(name,""); + } + + return *static_cast(block_table[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + const std::string& config_reader_kernel_1:: + operator[] ( + const std::string& key + ) const + { + if (is_key_defined(key) == false) + { + throw config_reader_access_error("",key); + } + + return key_table[key]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename queue_of_strings + > + void config_reader_kernel_1:: + get_keys ( + queue_of_strings& keys + ) const + { + keys.clear(); + key_table.reset(); + std::string temp; + while (key_table.move_next()) + { + temp = key_table.element().key(); + keys.enqueue(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_keys ( + std::vector& keys + ) const + { + keys.clear(); + key_table.reset(); + while (key_table.move_next()) + { + keys.push_back(key_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_keys ( + std_vector_c& keys + ) const + { + keys.clear(); + key_table.reset(); + while (key_table.move_next()) + { + keys.push_back(key_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename queue_of_strings + > + void config_reader_kernel_1:: + get_blocks ( + queue_of_strings& blocks + ) const + { + blocks.clear(); + block_table.reset(); + std::string temp; + while (block_table.move_next()) + { + temp = block_table.element().key(); + blocks.enqueue(temp); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_blocks ( + std::vector& blocks + ) const + { + blocks.clear(); + block_table.reset(); + while (block_table.move_next()) + { + blocks.push_back(block_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename map_string_string, + typename map_string_void, + typename tokenizer + > + template < + typename alloc + > + void config_reader_kernel_1:: + get_blocks ( + std_vector_c& blocks + ) const + { + blocks.clear(); + block_table.reset(); + while (block_table.move_next()) + { + blocks.push_back(block_table.element().key()); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONFIG_READER_KERNEl_1_ + diff --git a/dlib/config_reader/config_reader_kernel_abstract.h b/dlib/config_reader/config_reader_kernel_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..e8c44c2b2f965249b382e2a3708096f69f4bc68a --- /dev/null +++ b/dlib/config_reader/config_reader_kernel_abstract.h @@ -0,0 +1,363 @@ +// Copyright (C) 2003 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONFIG_READER_KERNEl_ABSTRACT_ +#ifdef DLIB_CONFIG_READER_KERNEl_ABSTRACT_ + +#include +#include + +namespace dlib +{ + + class config_reader + { + + /*! + INITIAL VALUE + - there aren't any keys defined for this object + - there aren't any blocks defined for this object + + POINTERS AND REFERENCES TO INTERNAL DATA + The destructor, clear(), and load_from() invalidate pointers + and references to internal data. All other functions are guaranteed + to NOT invalidate pointers or references to internal data. + + WHAT THIS OBJECT REPRESENTS + This object represents something which is intended to be used to read + text configuration files that are defined by the following EBNF (with + config_file as the starting symbol): + + config_file = block; + block = { key_value_pair | sub_block }; + key_value_pair = key_name, "=", value; + sub_block = block_name, "{", block, "}"; + + key_name = identifier; + block_name = identifier; + value = matches any string of text that ends with a newline character, # or }. + note that the trailing newline, # or } is not part of the value though. + identifier = Any string that matches the following regular expression: + [a-zA-Z][a-zA-Z0-9_-\.]* + i.e. Any string that starts with a letter and then is continued + with any number of letters, numbers, _ . or - characters. + + Whitespace and comments are ignored. A comment is text that starts with # (but not \# + since the \ escapes the # so that you can have a # symbol in a value if you want) and + ends in a new line. You can also escape a } (e.g. "\}") if you want to have one in a + value. + + Note that in a value the leading and trailing white spaces are stripped off but any + white space inside the value is preserved. + + Also note that all key_names and block_names within a block syntax group must be unique + but don't have to be globally unique. I.e. different blocks can reuse names. + + EXAMPLE CONFIG FILES: + + Example 1: + #comment. This line is ignored because it starts with # + + #here we have key1 which will have the value of "my value" + key1 = my value + + another_key= another value # this is another key called "another_key" with + # a value of "another value" + + # this key's value is the empty string. I.e. "" + key2= + + Example 2: + #this example illustrates the use of blocks + some_key = blah blah + + # now here is a block + our_block + { + # here we can define some keys and values that are local to this block. + a_key = something + foo = bar + some_key = more stuff # note that it is ok to name our key this even though + # there is a key called some_key above. This is because + # we are doing so inside a different block + } + + another_block { foo = bar2 } # this block has only one key and is all on a single line + !*/ + + public: + + // exception classes + class config_reader_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if there is an error while parsing the + config file. The type member of this exception will be set + to ECONFIG_READER. + + INTERPRETING THIS EXCEPTION + - line_number == the line number the parser was at when the + error occurred. + - if (redefinition) then + - The key or block name on line line_number has already + been defined in this scope which is an error. + - else + - Some other general syntax error was detected + !*/ + public: + const unsigned long line_number; + const bool redefinition; + }; + + class file_not_found : public dlib::error + { + /*! + GENERAL + This exception is thrown if the config file can't be opened for + some reason. The type member of this exception will be set + to ECONFIG_READER. + + INTERPRETING THIS EXCEPTION + - file_name == the name of the config file which we failed to open + !*/ + public: + const std::string file_name; + }; + + + class config_reader_access_error : public dlib::error + { + /*! + GENERAL + This exception is thrown if you try to access a key or + block that doesn't exist inside a config reader. The type + member of this exception will be set to ECONFIG_READER. + !*/ + public: + config_reader_access_error( + const std::string& block_name_, + const std::string& key_name_ + ); + /*! + ensures + - #block_name == block_name_ + - #key_name == key_name_ + !*/ + + const std::string block_name; + const std::string key_name; + }; + + // -------------------------- + + config_reader( + ); + /*! + ensures + - #*this is properly initialized + - This object will not have any keys or blocks defined in it. + throws + - std::bad_alloc + - config_reader_error + !*/ + + config_reader( + std::istream& in + ); + /*! + ensures + - #*this is properly initialized + - reads the config file to parse from the given input stream, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + - config_reader_error + !*/ + + config_reader( + const std::string& config_file + ); + /*! + ensures + - #*this is properly initialized + - parses the config file named by the config_file string. Specifically, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds in the file. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + - config_reader_error + - file_not_found + !*/ + + virtual ~config_reader( + ); + /*! + ensures + - all memory associated with *this has been released + !*/ + + void clear( + ); + /*! + ensures + - #*this has its initial value + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + !*/ + + void load_from ( + std::istream& in + ); + /*! + ensures + - reads the config file to parse from the given input stream, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - *this will represent the top most block of the config file contained + in the input stream in. + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + - config_reader_error + If this exception is thrown then this object will + revert to its initial value. + !*/ + + void load_from ( + const std::string& config_file + ); + /*! + ensures + - parses the config file named by the config_file string. Specifically, + parses it and loads this object up with all the sub blocks and + key/value pairs it finds in the file. + - before the load is performed, the previous state of the config file + reader is erased. So after the load the config file reader will contain + only information from the given config file. + - This object will represent the top most block of the config file. + throws + - std::bad_alloc + If this exception is thrown then *this is unusable + until clear() is called and succeeds + - config_reader_error + If this exception is thrown then this object will + revert to its initial value. + - file_not_found + If this exception is thrown then this object will + revert to its initial value. + !*/ + + bool is_key_defined ( + const std::string& key_name + ) const; + /*! + ensures + - if (there is a key with the given name defined within this config_reader's block) then + - returns true + - else + - returns false + !*/ + + bool is_block_defined ( + const std::string& block_name + ) const; + /*! + ensures + - if (there is a sub block with the given name defined within this config_reader's block) then + - returns true + - else + - returns false + !*/ + + typedef config_reader this_type; + const this_type& block ( + const std::string& block_name + ) const; + /*! + ensures + - if (is_block_defined(block_name) == true) then + - returns a const reference to the config_reader that represents the given named sub block + - else + - throws config_reader_access_error + throws + - config_reader_access_error + if this exception is thrown then its block_name field will be set to the + given block_name string. + !*/ + + const std::string& operator[] ( + const std::string& key_name + ) const; + /*! + ensures + - if (is_key_defined(key_name) == true) then + - returns a const reference to the value string associated with the given key in + this config_reader's block. + - else + - throws config_reader_access_error + throws + - config_reader_access_error + if this exception is thrown then its key_name field will be set to the + given key_name string. + !*/ + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + /*! + requires + - queue_of_strings is an implementation of queue/queue_kernel_abstract.h + with T set to std::string, or std::vector, or + dlib::std_vector_c + ensures + - #keys == a collection containing all the keys defined in this config_reader's block. + (i.e. for all strings str in keys it is the case that is_key_defined(str) == true) + !*/ + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + /*! + requires + - queue_of_strings is an implementation of queue/queue_kernel_abstract.h + with T set to std::string, or std::vector, or + dlib::std_vector_c + ensures + - #blocks == a collection containing the names of all the blocks defined in this + config_reader's block. + (i.e. for all strings str in blocks it is the case that is_block_defined(str) == true) + !*/ + + private: + + // restricted functions + config_reader(config_reader&); // copy constructor + config_reader& operator=(config_reader&); // assignment operator + + }; + +} + +#endif // DLIB_CONFIG_READER_KERNEl_ABSTRACT_ + diff --git a/dlib/config_reader/config_reader_thread_safe_1.h b/dlib/config_reader/config_reader_thread_safe_1.h new file mode 100644 index 0000000000000000000000000000000000000000..1ad250c99a4a0c1e65d0790c2432a727543f1a3a --- /dev/null +++ b/dlib/config_reader/config_reader_thread_safe_1.h @@ -0,0 +1,456 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONFIG_READER_THREAD_SAFe_ +#define DLIB_CONFIG_READER_THREAD_SAFe_ + +#include "config_reader_kernel_abstract.h" +#include +#include +#include +#include "../algs.h" +#include "../interfaces/enumerable.h" +#include "../threads.h" +#include "config_reader_thread_safe_abstract.h" + +namespace dlib +{ + + template < + typename config_reader_base, + typename map_string_void + > + class config_reader_thread_safe_1 + { + + /*! + CONVENTION + - get_mutex() == *m + - *cr == the config reader being extended + - block_table[x] == (void*)&block(x) + - block_table.size() == the number of blocks in *cr + - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) + - if (own_pointers) then + - this object owns the m and cr pointers and should delete them when destructed + !*/ + + public: + + config_reader_thread_safe_1 ( + const config_reader_base* base, + rmutex* m_ + ); + + config_reader_thread_safe_1(); + + typedef typename config_reader_base::config_reader_error config_reader_error; + typedef typename config_reader_base::config_reader_access_error config_reader_access_error; + + config_reader_thread_safe_1( + std::istream& in + ); + + config_reader_thread_safe_1( + const std::string& config_file + ); + + virtual ~config_reader_thread_safe_1( + ); + + void clear ( + ); + + void load_from ( + std::istream& in + ); + + void load_from ( + const std::string& config_file + ); + + bool is_key_defined ( + const std::string& key + ) const; + + bool is_block_defined ( + const std::string& name + ) const; + + typedef config_reader_thread_safe_1 this_type; + const this_type& block ( + const std::string& name + ) const; + + const std::string& operator[] ( + const std::string& key + ) const; + + template < + typename queue_of_strings + > + void get_keys ( + queue_of_strings& keys + ) const; + + template < + typename queue_of_strings + > + void get_blocks ( + queue_of_strings& blocks + ) const; + + inline const rmutex& get_mutex ( + ) const; + + private: + + void fill_block_table ( + ); + /*! + ensures + - block_table.size() == the number of blocks in cr + - block_table[key] == a config_reader_thread_safe_1 that contains &cr.block(key) + !*/ + + rmutex* m; + config_reader_base* cr; + map_string_void block_table; + const bool own_pointers; + + // restricted functions + config_reader_thread_safe_1(config_reader_thread_safe_1&); + config_reader_thread_safe_1& operator=(config_reader_thread_safe_1&); + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + // member function definitions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + const config_reader_base* base, + rmutex* m_ + ) : + m(m_), + cr(const_cast(base)), + own_pointers(false) + { + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base; + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + clear( + ) + { + auto_mutex M(*m); + cr->clear(); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + load_from( + std::istream& in + ) + { + auto_mutex M(*m); + cr->load_from(in); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + load_from( + const std::string& config_file + ) + { + auto_mutex M(*m); + cr->load_from(config_file); + fill_block_table(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + std::istream& in + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base(in); + fill_block_table(); + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + config_reader_thread_safe_1( + const std::string& config_file + ) : + m(0), + cr(0), + own_pointers(true) + { + try + { + m = new rmutex; + cr = new config_reader_base(config_file); + fill_block_table(); + } + catch (...) + { + if (m) delete m; + if (cr) delete cr; + throw; + } + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + config_reader_thread_safe_1:: + ~config_reader_thread_safe_1( + ) + { + if (own_pointers) + { + delete m; + delete cr; + } + + // clear out the block table + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + bool config_reader_thread_safe_1:: + is_key_defined ( + const std::string& key + ) const + { + auto_mutex M(*m); + return cr->is_key_defined(key); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + bool config_reader_thread_safe_1:: + is_block_defined ( + const std::string& name + ) const + { + auto_mutex M(*m); + return cr->is_block_defined(name); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const config_reader_thread_safe_1& config_reader_thread_safe_1:: + block ( + const std::string& name + ) const + { + auto_mutex M(*m); + if (block_table.is_in_domain(name) == false) + { + throw config_reader_access_error(name,""); + } + + return *static_cast(block_table[name]); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const std::string& config_reader_thread_safe_1:: + operator[] ( + const std::string& key + ) const + { + auto_mutex M(*m); + return (*cr)[key]; + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + template < + typename queue_of_strings + > + void config_reader_thread_safe_1:: + get_keys ( + queue_of_strings& keys + ) const + { + auto_mutex M(*m); + cr->get_keys(keys); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + template < + typename queue_of_strings + > + void config_reader_thread_safe_1:: + get_blocks ( + queue_of_strings& blocks + ) const + { + auto_mutex M(*m); + cr->get_blocks(blocks); + } + +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + const rmutex& config_reader_thread_safe_1:: + get_mutex ( + ) const + { + return *m; + } + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// private member functions +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + template < + typename config_reader_base, + typename map_string_void + > + void config_reader_thread_safe_1:: + fill_block_table ( + ) + { + using namespace std; + // first empty out the block table + block_table.reset(); + while (block_table.move_next()) + { + delete static_cast(block_table.element().value()); + } + block_table.clear(); + + std::vector blocks; + cr->get_blocks(blocks); + + // now fill the block table up to match what is in cr + for (unsigned long i = 0; i < blocks.size(); ++i) + { + config_reader_thread_safe_1* block = new config_reader_thread_safe_1(&cr->block(blocks[i]),m); + void* temp = block; + block_table.add(blocks[i],temp); + } + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONFIG_READER_THREAD_SAFe_ + + diff --git a/dlib/config_reader/config_reader_thread_safe_abstract.h b/dlib/config_reader/config_reader_thread_safe_abstract.h new file mode 100644 index 0000000000000000000000000000000000000000..25bcbae4a9c2dd009f3772c6b4879a2620cf9433 --- /dev/null +++ b/dlib/config_reader/config_reader_thread_safe_abstract.h @@ -0,0 +1,45 @@ +// Copyright (C) 2007 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#undef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ +#ifdef DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ + +#include +#include +#include "config_reader_kernel_abstract.h" +#include "../threads/threads_kernel_abstract.h" + +namespace dlib +{ + + class config_reader_thread_safe + { + + /*! + WHAT THIS EXTENSION DOES FOR config_reader + This object extends a normal config_reader by simply wrapping all + its member functions inside mutex locks to make it safe to use + in a threaded program. + + So this object provides an interface identical to the one defined + in the config_reader/config_reader_kernel_abstract.h file except that + the rmutex returned by get_mutex() is always locked when this + object's member functions are called. + !*/ + + public: + + const rmutex& get_mutex ( + ) const; + /*! + ensures + - returns the rmutex used to make this object thread safe. i.e. returns + the rmutex that is locked when this object's functions are called. + !*/ + + }; + +} + +#endif // DLIB_CONFIG_READER_THREAD_SAFe_ABSTRACT_ + + diff --git a/dlib/console_progress_indicator.h b/dlib/console_progress_indicator.h new file mode 100644 index 0000000000000000000000000000000000000000..42e6fa679070e2eaf799a878ef207fd6b4538f63 --- /dev/null +++ b/dlib/console_progress_indicator.h @@ -0,0 +1,256 @@ +// Copyright (C) 2010 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ +#define DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ + +#include +#include +#include +#include + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + class console_progress_indicator + { + /*! + WHAT THIS OBJECT REPRESENTS + This object is a tool for reporting how long a task will take + to complete. + + For example, consider the following bit of code: + + console_progress_indicator pbar(100) + for (int i = 1; i <= 100; ++i) + { + pbar.print_status(i); + long_running_operation(); + } + + The above code will print a message to the console each iteration + which shows the current progress and how much time is remaining until + the loop terminates. + !*/ + + public: + + inline explicit console_progress_indicator ( + double target_value + ); + /*! + ensures + - #target() == target_value + !*/ + + inline void reset ( + double target_value + ); + /*! + ensures + - #target() == target_value + - performs the equivalent of: + *this = console_progress_indicator(target_value) + (i.e. resets this object with a new target value) + + !*/ + + inline double target ( + ) const; + /*! + ensures + - This object attempts to measure how much time is + left until we reach a certain targeted value. This + function returns that targeted value. + !*/ + + inline bool print_status ( + double cur, + bool always_print = false, + std::ostream& out = std::clog + ); + /*! + ensures + - print_status() assumes it is called with values which are linearly + approaching target(). It will display the current progress and attempt + to predict how much time is remaining until cur becomes equal to target(). + - prints a status message to out which indicates how much more time is + left until cur is equal to target() + - if (always_print) then + - This function prints to the screen each time it is called. + - else + - This function throttles the printing so that at most 1 message is + printed each second. Note that it won't print anything to the screen + until about one second has elapsed. This means that the first call + to print_status() never prints to the screen. + - This function returns true if it prints to the screen and false + otherwise. + !*/ + + inline void finish ( + std::ostream& out = std::cout + ) const; + /*! + ensures + - This object prints the completed progress and the elapsed time to out. + It is meant to be called after the loop we are tracking the progress of. + !*/ + + private: + + double target_val; + std::chrono::time_point start_time; + double first_val; + double seen_first_val; + std::chrono::time_point last_time; + + }; + +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- +// IMPLEMENTATION DETAILS +// ---------------------------------------------------------------------------------------- +// ---------------------------------------------------------------------------------------- + + console_progress_indicator:: + console_progress_indicator ( + double target_value + ) : + target_val(target_value), + start_time(std::chrono::steady_clock::now()), + first_val(0), + seen_first_val(false), + last_time(std::chrono::steady_clock::now()) + { + } + +// ---------------------------------------------------------------------------------------- + + bool console_progress_indicator:: + print_status ( + double cur, + bool always_print, + std::ostream& out + ) + { + const auto cur_time = std::chrono::steady_clock::now(); + + // if this is the first time print_status has been called + // then collect some information and exit. We will print status + // on the next call. + if (!seen_first_val) + { + start_time = cur_time; + last_time = cur_time; + first_val = cur; + seen_first_val = true; + return false; + } + + if ((cur_time - last_time) >= std::chrono::seconds(1) || always_print) + { + last_time = cur_time; + const auto delta_t = cur_time - start_time; + double delta_val = std::abs(cur - first_val); + + // don't do anything if cur is equal to first_val + if (delta_val < std::numeric_limits::epsilon()) + return false; + + const auto rem_time = delta_t / delta_val * std::abs(target_val - cur); + + const auto oldflags = out.flags(); + out.setf(std::ios::fixed,std::ios::floatfield); + std::streamsize ss; + + // adapt the precision based on whether the target val is an integer + if (std::trunc(target_val) == target_val) + ss = out.precision(0); + else + ss = out.precision(2); + + out << "Progress: " << cur << "/" << target_val; + ss = out.precision(2); + out << " (" << cur / target_val * 100. << "%). "; + + const auto hours = std::chrono::duration_cast(rem_time); + const auto minutes = std::chrono::duration_cast(rem_time) - hours; + const auto seconds = std::chrono::duration_cast(rem_time) - hours - minutes; + out << "Time remaining: "; + if (rem_time >= std::chrono::hours(1)) + out << hours.count() << "h "; + if (rem_time >= std::chrono::minutes(1)) + out << minutes.count() << "min "; + out << seconds.count() << "s. \r" << std::flush; + + // restore previous output flags and precision settings + out.flags(oldflags); + out.precision(ss); + + return true; + } + + return false; + } + +// ---------------------------------------------------------------------------------------- + + double console_progress_indicator:: + target ( + ) const + { + return target_val; + } + +// ---------------------------------------------------------------------------------------- + + void console_progress_indicator:: + reset ( + double target_value + ) + { + *this = console_progress_indicator(target_value); + } + +// ---------------------------------------------------------------------------------------- + + void console_progress_indicator:: + finish ( + std::ostream& out + ) const + { + const auto oldflags = out.flags(); + out.setf(std::ios::fixed,std::ios::floatfield); + std::streamsize ss; + + // adapt the precision based on whether the target val is an integer + if (std::trunc(target_val) == target_val) + ss = out.precision(0); + else + ss = out.precision(2); + + out << "Progress: " << target_val << "/" << target_val; + out << " (100.00%). "; + const auto delta_t = std::chrono::steady_clock::now() - start_time; + const auto hours = std::chrono::duration_cast(delta_t); + const auto minutes = std::chrono::duration_cast(delta_t) - hours; + const auto seconds = std::chrono::duration_cast(delta_t) - hours - minutes; + out << "Time elapsed: "; + if (delta_t >= std::chrono::hours(1)) + out << hours.count() << "h "; + if (delta_t >= std::chrono::minutes(1)) + out << minutes.count() << "min "; + out << seconds.count() << "s. " << std::endl; + + // restore previous output flags and precision settings + out.flags(oldflags); + out.precision(ss); + } + +// ---------------------------------------------------------------------------------------- + +} + +#endif // DLIB_CONSOLE_PROGRESS_INDiCATOR_Hh_ + diff --git a/dlib/constexpr_if.h b/dlib/constexpr_if.h new file mode 100644 index 0000000000000000000000000000000000000000..2c22ddf5e3e05b9f2c621a50d20d8092b2be3eae --- /dev/null +++ b/dlib/constexpr_if.h @@ -0,0 +1,136 @@ +// Copyright (C) 2022 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_IF_CONSTEXPR_H +#define DLIB_IF_CONSTEXPR_H + +#include "overloaded.h" +#include "type_traits.h" + +namespace dlib +{ +// ---------------------------------------------------------------------------------------- + + namespace detail + { + const auto _ = [](auto&& arg) -> decltype(auto) { return std::forward(arg); }; + + template class Op, class... Args> + struct is_detected : std::false_type{}; + + template