Commit d53e923b authored by pkufool

Move k2 rnnt_loss here

parent b5828e2b
@@ -2,3 +2,4 @@
 .idea
 venv*
 deploy*
+__pycache__/*
if("x${CMAKE_SOURCE_DIR}" STREQUAL "x${CMAKE_BINARY_DIR}")
message(FATAL_ERROR "\
In-source build is not a good practice.
Please use:
mkdir build
cd build
cmake ..
to build this project"
)
endif()
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
set(languages CXX)
set(_FT_WITH_CUDA ON)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_program(FT_HAS_NVCC nvcc)
if(NOT FT_HAS_NVCC AND "$ENV{CUDACXX}" STREQUAL "")
message(STATUS "No NVCC detected. Disable CUDA support")
set(_FT_WITH_CUDA OFF)
endif()
if(APPLE OR (DEFINED FT_WITH_CUDA AND NOT FT_WITH_CUDA))
if(_FT_WITH_CUDA)
message(STATUS "Disable CUDA support")
set(_FT_WITH_CUDA OFF)
endif()
endif()
if(_FT_WITH_CUDA)
set(languages ${languages} CUDA)
if(NOT DEFINED FT_WITH_CUDA)
set(FT_WITH_CUDA ON)
endif()
endif()
message(STATUS "Enabled languages: ${languages}")
project(fast_rnnt ${languages})
set(FT_VERSION "1.0")
set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
set(DEFAULT_BUILD_TYPE "Release")
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "${ALLOWABLE_BUILD_TYPES}")
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
# CMAKE_CONFIGURATION_TYPES is set by multi-config generators (e.g. IDEs) instead of CMAKE_BUILD_TYPE.
message(STATUS "No CMAKE_BUILD_TYPE given, default to ${DEFAULT_BUILD_TYPE}")
set(CMAKE_BUILD_TYPE "${DEFAULT_BUILD_TYPE}")
elseif(NOT CMAKE_BUILD_TYPE IN_LIST ALLOWABLE_BUILD_TYPES)
message(FATAL_ERROR "Invalid build type: ${CMAKE_BUILD_TYPE}, \
choose one from ${ALLOWABLE_BUILD_TYPES}")
endif()
option(FT_BUILD_TESTS "Whether to build tests or not" OFF)
option(BUILD_SHARED_LIBS "Whether to build shared libs" ON)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(BUILD_RPATH_USE_ORIGIN TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
set(CMAKE_INSTALL_RPATH "$ORIGIN")
set(CMAKE_BUILD_RPATH "$ORIGIN")
if(FT_WITH_CUDA)
add_definitions(-DFT_WITH_CUDA)
# Force CUDA C++ standard to be the same as the C++ standard used.
#
# Now, CMake is unaligned with reality on standard versions: https://gitlab.kitware.com/cmake/cmake/issues/18597
# which means that using standard CMake methods, it's impossible to actually sync the CXX and CUDA versions for pre-11
# versions of C++; CUDA accepts 98 but translates that to 03, while CXX doesn't accept 03 (and doesn't translate that to 03).
# In case this gives You, dear user, any trouble, please escalate the above CMake bug, so we can support reality properly.
if(DEFINED CMAKE_CUDA_STANDARD)
message(WARNING "You've set CMAKE_CUDA_STANDARD; please note that this variable is ignored, and CMAKE_CXX_STANDARD"
" is used as the C++ standard version for both C++ and CUDA.")
endif()
unset(CMAKE_CUDA_STANDARD CACHE)
set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD})
include(cmake/select_compute_arch.cmake)
cuda_select_nvcc_arch_flags(FT_COMPUTE_ARCH_FLAGS)
message(STATUS "FT_COMPUTE_ARCH_FLAGS: ${FT_COMPUTE_ARCH_FLAGS}")
# set(OT_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72)
# message(WARNING "arch 62/72 are not supported for now")
# see https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/
# https://www.myzhar.com/blog/tutorials/tutorial-nvidia-gpu-cuda-compute-capability/
set(FT_COMPUTE_ARCH_CANDIDATES 35 50 60 61 70 75)
if(CUDA_VERSION VERSION_GREATER "11.0")
list(APPEND FT_COMPUTE_ARCH_CANDIDATES 80 86)
endif()
message(STATUS "FT_COMPUTE_ARCH_CANDIDATES ${FT_COMPUTE_ARCH_CANDIDATES}")
set(FT_COMPUTE_ARCHS)
foreach(COMPUTE_ARCH IN LISTS FT_COMPUTE_ARCH_CANDIDATES)
if("${FT_COMPUTE_ARCH_FLAGS}" MATCHES ${COMPUTE_ARCH})
message(STATUS "Adding arch ${COMPUTE_ARCH}")
list(APPEND FT_COMPUTE_ARCHS ${COMPUTE_ARCH})
else()
message(STATUS "Skipping arch ${COMPUTE_ARCH}")
endif()
endforeach()
if(NOT FT_COMPUTE_ARCHS)
set(FT_COMPUTE_ARCHS ${FT_COMPUTE_ARCH_CANDIDATES})
endif()
message(STATUS "FT_COMPUTE_ARCHS: ${FT_COMPUTE_ARCHS}")
foreach(COMPUTE_ARCH IN LISTS FT_COMPUTE_ARCHS)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda -gencode arch=compute_${COMPUTE_ARCH},code=sm_${COMPUTE_ARCH}")
set(CMAKE_CUDA_ARCHITECTURES "${COMPUTE_ARCH}-real;${COMPUTE_ARCH}-virtual;${CMAKE_CUDA_ARCHITECTURES}")
endforeach()
endif()
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(pybind11)
include(torch)
if(FT_BUILD_TESTS)
enable_testing()
include(googletest)
endif()
add_subdirectory(fast_rnnt)
# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
# file Copyright.txt or https://cmake.org/licensing for details.
#[=======================================================================[.rst:
FetchContent
------------------
.. only:: html
.. contents::
Overview
^^^^^^^^
This module enables populating content at configure time via any method
supported by the :module:`ExternalProject` module. Whereas
:command:`ExternalProject_Add` downloads at build time, the
``FetchContent`` module makes content available immediately, allowing the
configure step to use the content in commands like :command:`add_subdirectory`,
:command:`include` or :command:`file` operations.
Content population details would normally be defined separately from the
command that performs the actual population. Projects should also
check whether the content has already been populated somewhere else in the
project hierarchy. Typical usage would look something like this:
.. code-block:: cmake
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.8.0
)
FetchContent_GetProperties(googletest)
if(NOT googletest_POPULATED)
FetchContent_Populate(googletest)
add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR})
endif()
When using the above pattern with a hierarchical project arrangement,
projects at higher levels in the hierarchy are able to define or override
the population details of content specified anywhere lower in the project
hierarchy. The ability to detect whether content has already been
populated ensures that even if multiple child projects want certain content
to be available, the first one to populate it wins. The other child project
can simply make use of the already available content instead of repeating
the population for itself. See the
:ref:`Examples <fetch-content-examples>` section which demonstrates
this scenario.
The ``FetchContent`` module also supports defining and populating
content in a single call, with no check for whether the content has been
populated elsewhere in the project already. This is a more low level
operation and would not normally be the way the module is used, but it is
sometimes useful as part of implementing some higher level feature or to
populate some content in CMake's script mode.
Declaring Content Details
^^^^^^^^^^^^^^^^^^^^^^^^^
.. command:: FetchContent_Declare
.. code-block:: cmake
FetchContent_Declare(<name> <contentOptions>...)
The ``FetchContent_Declare()`` function records the options that describe
how to populate the specified content, but if such details have already
been recorded earlier in this project (regardless of where in the project
hierarchy), this and all later calls for the same content ``<name>`` are
ignored. This "first to record, wins" approach is what allows hierarchical
projects to have parent projects override content details of child projects.
The content ``<name>`` can be any string without spaces, but good practice
would be to use only letters, numbers and underscores. The name will be
treated case-insensitively and it should be obvious for the content it
represents, often being the name of the child project or the value given
to its top level :command:`project` command (if it is a CMake project).
For well-known public projects, the name should generally be the official
name of the project. Choosing an unusual name makes it unlikely that other
projects needing that same content will use the same name, leading to
the content being populated multiple times.
The ``<contentOptions>`` can be any of the download or update/patch options
that the :command:`ExternalProject_Add` command understands. The configure,
build, install and test steps are explicitly disabled and therefore options
related to them will be ignored. In most cases, ``<contentOptions>`` will
just be a couple of options defining the download method and method-specific
details like a commit tag or archive hash. For example:
.. code-block:: cmake
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG release-1.8.0
)
FetchContent_Declare(
myCompanyIcons
URL https://intranet.mycompany.com/assets/iconset_1.12.tar.gz
URL_HASH 5588a7b18261c20068beabfb4f530b87
)
FetchContent_Declare(
myCompanyCertificates
SVN_REPOSITORY svn+ssh://svn.mycompany.com/srv/svn/trunk/certs
SVN_REVISION -r12345
)
Populating The Content
^^^^^^^^^^^^^^^^^^^^^^
.. command:: FetchContent_Populate
.. code-block:: cmake
FetchContent_Populate( <name> )
In most cases, the only argument given to ``FetchContent_Populate()`` is the
``<name>``. When used this way, the command assumes the content details have
been recorded by an earlier call to :command:`FetchContent_Declare`. The
details are stored in a global property, so they are unaffected by things
like variable or directory scope. Therefore, it doesn't matter where in the
project the details were previously declared, as long as they have been
declared before the call to ``FetchContent_Populate()``. Those saved details
are then used to construct a call to :command:`ExternalProject_Add` in a
private sub-build to perform the content population immediately. The
implementation of ``ExternalProject_Add()`` ensures that if the content has
already been populated in a previous CMake run, that content will be reused
rather than repopulating it again. For the common case where population
involves downloading content, the cost of the download is only paid once.
An internal global property records when a particular content population
request has been processed. If ``FetchContent_Populate()`` is called more
than once for the same content name within a configure run, the second call
will halt with an error. Projects can and should check whether content
population has already been processed with the
:command:`FetchContent_GetProperties` command before calling
``FetchContent_Populate()``.
``FetchContent_Populate()`` will set three variables in the scope of the
caller; ``<lcName>_POPULATED``, ``<lcName>_SOURCE_DIR`` and
``<lcName>_BINARY_DIR``, where ``<lcName>`` is the lowercased ``<name>``.
``<lcName>_POPULATED`` will always be set to ``True`` by the call.
``<lcName>_SOURCE_DIR`` is the location where the
content can be found upon return (it will have already been populated), while
``<lcName>_BINARY_DIR`` is a directory intended for use as a corresponding
build directory. The main use case for the two directory variables is to
call :command:`add_subdirectory` immediately after population, i.e.:
.. code-block:: cmake
FetchContent_Populate(FooBar ...)
add_subdirectory(${foobar_SOURCE_DIR} ${foobar_BINARY_DIR})
The values of the three variables can also be retrieved from anywhere in the
project hierarchy using the :command:`FetchContent_GetProperties` command.
A number of cache variables influence the behavior of all content population
performed using details saved from a :command:`FetchContent_Declare` call:
``FETCHCONTENT_BASE_DIR``
In most cases, the saved details do not specify any options relating to the
directories to use for the internal sub-build, final source and build areas.
It is generally best to leave these decisions up to the ``FetchContent``
module to handle on the project's behalf. The ``FETCHCONTENT_BASE_DIR``
cache variable controls the point under which all content population
directories are collected, but in most cases developers would not need to
change this. The default location is ``${CMAKE_BINARY_DIR}/_deps``, but if
developers change this value, they should aim to keep the path short and
just below the top level of the build tree to avoid running into path
length problems on Windows.
``FETCHCONTENT_QUIET``
The logging output during population can be quite verbose, making the
configure stage quite noisy. This cache option (``ON`` by default) hides
all population output unless an error is encountered. If experiencing
problems with hung downloads, temporarily switching this option off may
help diagnose which content population is causing the issue.
``FETCHCONTENT_FULLY_DISCONNECTED``
When this option is enabled, no attempt is made to download or update
any content. It is assumed that all content has already been populated in
a previous run or the source directories have been pointed at existing
contents the developer has provided manually (using options described
further below). When the developer knows that no changes have been made to
any content details, turning this option ``ON`` can significantly speed up
the configure stage. It is ``OFF`` by default.
``FETCHCONTENT_UPDATES_DISCONNECTED``
This is a less severe download/update control compared to
``FETCHCONTENT_FULLY_DISCONNECTED``. Instead of bypassing all download and
update logic, the ``FETCHCONTENT_UPDATES_DISCONNECTED`` only disables the
update stage. Therefore, if content has not been downloaded previously,
it will still be downloaded when this option is enabled. This can speed up
the configure stage, but not as much as
``FETCHCONTENT_FULLY_DISCONNECTED``. It is ``OFF`` by default.
In addition to the above cache variables, the following cache variables are
also defined for each content name (``<ucName>`` is the uppercased value of
``<name>``):
``FETCHCONTENT_SOURCE_DIR_<ucName>``
If this is set, no download or update steps are performed for the specified
content and the ``<lcName>_SOURCE_DIR`` variable returned to the caller is
pointed at this location. This gives developers a way to have a separate
checkout of the content that they can modify freely without interference
from the build. The build simply uses that existing source, but it still
defines ``<lcName>_BINARY_DIR`` to point inside its own build area.
Developers are strongly encouraged to use this mechanism rather than
editing the sources populated in the default location, as changes to
sources in the default location can be lost when content population details
are changed by the project.
``FETCHCONTENT_UPDATES_DISCONNECTED_<ucName>``
This is the per-content equivalent of
``FETCHCONTENT_UPDATES_DISCONNECTED``. If the global option or this option
is ``ON``, then updates will be disabled for the named content.
Disabling updates for individual content can be useful for content whose
details rarely change, while still leaving other frequently changing
content with updates enabled.
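For illustration only (the values and the content name ``googletest`` below are
hypothetical), these cache variables could be pre-set before the first population,
e.g. near the top of the consuming project or via an initial cache:
.. code-block:: cmake
# Show population output instead of hiding it
set(FETCHCONTENT_QUIET OFF CACHE BOOL "" FORCE)
# Skip the update step for all content
set(FETCHCONTENT_UPDATES_DISCONNECTED ON CACHE BOOL "" FORCE)
# Point the googletest content at an existing local checkout (hypothetical path)
set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST "/path/to/local/googletest" CACHE PATH "" FORCE)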
The ``FetchContent_Populate()`` command also supports a syntax allowing the
content details to be specified directly rather than using any saved
details. This is more low-level and use of this form is generally to be
avoided in favour of using saved content details as outlined above.
Nevertheless, in certain situations it can be useful to invoke the content
population as an isolated operation (typically as part of implementing some
other higher level feature or when using CMake in script mode):
.. code-block:: cmake
FetchContent_Populate( <name>
[QUIET]
[SUBBUILD_DIR <subBuildDir>]
[SOURCE_DIR <srcDir>]
[BINARY_DIR <binDir>]
...
)
This form has a number of key differences to that where only ``<name>`` is
provided:
- All required population details are assumed to have been provided directly
in the call to ``FetchContent_Populate()``. Any saved details for
``<name>`` are ignored.
- No check is made for whether content for ``<name>`` has already been
populated.
- No global property is set to record that the population has occurred.
- No global properties record the source or binary directories used for the
populated content.
- The ``FETCHCONTENT_FULLY_DISCONNECTED`` and
``FETCHCONTENT_UPDATES_DISCONNECTED`` cache variables are ignored.
The ``<lcName>_SOURCE_DIR`` and ``<lcName>_BINARY_DIR`` variables are still
returned to the caller, but since these locations are not stored as global
properties when this form is used, they are only available to the calling
scope and below rather than the entire project hierarchy. No
``<lcName>_POPULATED`` variable is set in the caller's scope with this form.
The supported options for ``FetchContent_Populate()`` are the same as those
for :command:`FetchContent_Declare()`. Those few options shown just
above are either specific to ``FetchContent_Populate()`` or their behavior is
slightly modified from how :command:`ExternalProject_Add` treats them.
``QUIET``
The ``QUIET`` option can be given to hide the output associated with
populating the specified content. If the population fails, the output will
be shown regardless of whether this option was given or not so that the
cause of the failure can be diagnosed. The global ``FETCHCONTENT_QUIET``
cache variable has no effect on ``FetchContent_Populate()`` calls where the
content details are provided directly.
``SUBBUILD_DIR``
The ``SUBBUILD_DIR`` argument can be provided to change the location of the
sub-build created to perform the population. The default value is
``${CMAKE_CURRENT_BINARY_DIR}/<lcName>-subbuild`` and it would be unusual
to need to override this default. If a relative path is specified, it will
be interpreted as relative to :variable:`CMAKE_CURRENT_BINARY_DIR`.
``SOURCE_DIR``, ``BINARY_DIR``
The ``SOURCE_DIR`` and ``BINARY_DIR`` arguments are supported by
:command:`ExternalProject_Add`, but different default values are used by
``FetchContent_Populate()``. ``SOURCE_DIR`` defaults to
``${CMAKE_CURRENT_BINARY_DIR}/<lcName>-src`` and ``BINARY_DIR`` defaults to
``${CMAKE_CURRENT_BINARY_DIR}/<lcName>-build``. If a relative path is
specified, it will be interpreted as relative to
:variable:`CMAKE_CURRENT_BINARY_DIR`.
In addition to the above explicit options, any other unrecognized options are
passed through unmodified to :command:`ExternalProject_Add` to perform the
download, patch and update steps. The following options are explicitly
prohibited (they are disabled by the ``FetchContent_Populate()`` command):
- ``CONFIGURE_COMMAND``
- ``BUILD_COMMAND``
- ``INSTALL_COMMAND``
- ``TEST_COMMAND``
If using ``FetchContent_Populate()`` within CMake's script mode, be aware
that the implementation sets up a sub-build which therefore requires a CMake
generator and build tool to be available. If these cannot be found by
default, then the :variable:`CMAKE_GENERATOR` and/or
:variable:`CMAKE_MAKE_PROGRAM` variables will need to be set appropriately
on the command line invoking the script.
Retrieve Population Properties
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. command:: FetchContent_GetProperties
When using saved content details, a call to :command:`FetchContent_Populate`
records information in global properties which can be queried at any time.
This information includes the source and binary directories associated with
the content and also whether or not the content population has been processed
during the current configure run.
.. code-block:: cmake
FetchContent_GetProperties( <name>
[SOURCE_DIR <srcDirVar>]
[BINARY_DIR <binDirVar>]
[POPULATED <doneVar>]
)
The ``SOURCE_DIR``, ``BINARY_DIR`` and ``POPULATED`` options can be used to
specify which properties should be retrieved. Each option accepts a value
which is the name of the variable in which to store that property. Most of
the time though, only ``<name>`` is given, in which case the call will then
set the same variables as a call to
:command:`FetchContent_Populate(name) <FetchContent_Populate>`. This allows
the following canonical pattern to be used, which ensures that the relevant
variables will always be defined regardless of whether or not the population
has been performed elsewhere in the project already:
.. code-block:: cmake
FetchContent_GetProperties(foobar)
if(NOT foobar_POPULATED)
FetchContent_Populate(foobar)
# Set any custom variables, etc. here, then
# populate the content as part of this build
add_subdirectory(${foobar_SOURCE_DIR} ${foobar_BINARY_DIR})
endif()
The above pattern allows other parts of the overall project hierarchy to
re-use the same content and ensure that it is only populated once.
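When only specific properties are needed, the explicit options can be given instead
of relying on the default variable names. A minimal sketch (the variable names
``gtestSrcDir`` and ``gtestDone`` are illustrative):
.. code-block:: cmake
FetchContent_GetProperties(googletest SOURCE_DIR gtestSrcDir POPULATED gtestDone)
# gtestSrcDir now holds the saved source directory and gtestDone whether the
# googletest content has already been populated during this configure run.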
.. _`fetch-content-examples`:
Examples
^^^^^^^^
Consider a project hierarchy where ``projA`` is the top level project and it
depends on projects ``projB`` and ``projC``. Both ``projB`` and ``projC``
can be built standalone and they also both depend on another project
``projD``. For simplicity, this example will assume that all four projects
are available on a company git server. The ``CMakeLists.txt`` of each project
might have sections like the following:
*projA*:
.. code-block:: cmake
include(FetchContent)
FetchContent_Declare(
projB
GIT_REPOSITORY git@mycompany.com/git/projB.git
GIT_TAG 4a89dc7e24ff212a7b5167bef7ab079d
)
FetchContent_Declare(
projC
GIT_REPOSITORY git@mycompany.com/git/projC.git
GIT_TAG 4ad4016bd1d8d5412d135cf8ceea1bb9
)
FetchContent_Declare(
projD
GIT_REPOSITORY git@mycompany.com/git/projD.git
GIT_TAG origin/integrationBranch
)
FetchContent_GetProperties(projB)
if(NOT projb_POPULATED)
FetchContent_Populate(projB)
add_subdirectory(${projb_SOURCE_DIR} ${projb_BINARY_DIR})
endif()
FetchContent_GetProperties(projC)
if(NOT projc_POPULATED)
FetchContent_Populate(projC)
add_subdirectory(${projc_SOURCE_DIR} ${projc_BINARY_DIR})
endif()
*projB*:
.. code-block:: cmake
include(FetchContent)
FetchContent_Declare(
projD
GIT_REPOSITORY git@mycompany.com/git/projD.git
GIT_TAG 20b415f9034bbd2a2e8216e9a5c9e632
)
FetchContent_GetProperties(projD)
if(NOT projd_POPULATED)
FetchContent_Populate(projD)
add_subdirectory(${projd_SOURCE_DIR} ${projd_BINARY_DIR})
endif()
*projC*:
.. code-block:: cmake
include(FetchContent)
FetchContent_Declare(
projD
GIT_REPOSITORY git@mycompany.com/git/projD.git
GIT_TAG 7d9a17ad2c962aa13e2fbb8043fb6b8a
)
FetchContent_GetProperties(projD)
if(NOT projd_POPULATED)
FetchContent_Populate(projD)
add_subdirectory(${projd_SOURCE_DIR} ${projd_BINARY_DIR})
endif()
A few key points should be noted in the above:
- ``projB`` and ``projC`` define different content details for ``projD``,
but ``projA`` also defines a set of content details for ``projD`` and
because ``projA`` will define them first, the details from ``projB`` and
``projC`` will not be used. The override details defined by ``projA``
are not required to match either of those from ``projB`` or ``projC``, but
it is up to the higher level project to ensure that the details it does
define still make sense for the child projects.
- While ``projA`` defined content details for ``projD``, it did not need
to explicitly call ``FetchContent_Populate(projD)`` itself. Instead, it
leaves that to a child project to do (in this case it will be ``projB``
since it is added to the build ahead of ``projC``). If ``projA`` needed to
customize how the ``projD`` content was brought into the build as well
(e.g. define some CMake variables before calling
:command:`add_subdirectory` after populating), it would do the call to
``FetchContent_Populate()``, etc. just as it did for the ``projB`` and
``projC`` content. For higher level projects, it is usually enough to
just define the override content details and leave the actual population
to the child projects. This saves repeating the same thing at each level
of the project hierarchy unnecessarily.
- Even though ``projA`` is the top level project in this example, it still
checks whether ``projB`` and ``projC`` have already been populated before
going ahead to do those populations. This makes ``projA`` able to be more
easily incorporated as a child of some other higher level project in the
future if required. Always protect a call to
:command:`FetchContent_Populate` with a check to
:command:`FetchContent_GetProperties`, even in what may be considered a top
level project at the time.
The following example demonstrates how one might download and unpack a
firmware tarball using CMake's :manual:`script mode <cmake(1)>`. The call to
:command:`FetchContent_Populate` specifies all the content details and the
unpacked firmware will be placed in a ``firmware`` directory below the
current working directory.
*getFirmware.cmake*:
.. code-block:: cmake
# NOTE: Intended to be run in script mode with cmake -P
include(FetchContent)
FetchContent_Populate(
firmware
URL https://mycompany.com/assets/firmware-1.23-arm.tar.gz
URL_HASH MD5=68247684da89b608d466253762b0ff11
SOURCE_DIR firmware
)
#]=======================================================================]
set(__FetchContent_privateDir "${CMAKE_CURRENT_LIST_DIR}/FetchContent")
#=======================================================================
# Recording and retrieving content details for later population
#=======================================================================
# Internal use, projects must not call this directly. It is
# intended for use by FetchContent_Declare() only.
#
# Sets a content-specific global property (not meant for use
# outside of functions defined here in this file) which can later
# be retrieved using __FetchContent_getSavedDetails() with just the
# same content name. If there is already a value stored in the
# property, it is left unchanged and this call has no effect.
# This allows parent projects to define the content details,
# overriding anything a child project may try to set (properties
# are not cached between runs, so the first thing to set it in a
# build will be in control).
function(__FetchContent_declareDetails contentName)
string(TOLOWER ${contentName} contentNameLower)
set(propertyName "_FetchContent_${contentNameLower}_savedDetails")
get_property(alreadyDefined GLOBAL PROPERTY ${propertyName} DEFINED)
if(NOT alreadyDefined)
define_property(GLOBAL PROPERTY ${propertyName}
BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()"
FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}"
)
set_property(GLOBAL PROPERTY ${propertyName} ${ARGN})
endif()
endfunction()
# Internal use, projects must not call this directly. It is
# intended for use by the FetchContent_Declare() function.
#
# Retrieves details saved for the specified content in an
# earlier call to __FetchContent_declareDetails().
function(__FetchContent_getSavedDetails contentName outVar)
string(TOLOWER ${contentName} contentNameLower)
set(propertyName "_FetchContent_${contentNameLower}_savedDetails")
get_property(alreadyDefined GLOBAL PROPERTY ${propertyName} DEFINED)
if(NOT alreadyDefined)
message(FATAL_ERROR "No content details recorded for ${contentName}")
endif()
get_property(propertyValue GLOBAL PROPERTY ${propertyName})
set(${outVar} "${propertyValue}" PARENT_SCOPE)
endfunction()
# Saves population details of the content, sets defaults for the
# SOURCE_DIR and BINARY_DIR.
function(FetchContent_Declare contentName)
set(options "")
set(oneValueArgs SVN_REPOSITORY)
set(multiValueArgs "")
cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
unset(srcDirSuffix)
unset(svnRepoArgs)
if(ARG_SVN_REPOSITORY)
# Add a hash of the svn repository URL to the source dir. This works
# around the problem where if the URL changes, the download would
# fail because it tries to checkout/update rather than switch the
# old URL to the new one. We limit the hash to the first 7 characters
# so that the source path doesn't get overly long (which can be a
# problem on windows due to path length limits).
string(SHA1 urlSHA ${ARG_SVN_REPOSITORY})
string(SUBSTRING ${urlSHA} 0 7 urlSHA)
set(srcDirSuffix "-${urlSHA}")
set(svnRepoArgs SVN_REPOSITORY ${ARG_SVN_REPOSITORY})
endif()
string(TOLOWER ${contentName} contentNameLower)
__FetchContent_declareDetails(
${contentNameLower}
SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-src${srcDirSuffix}"
BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build"
${svnRepoArgs}
# List these last so they can override things we set above
${ARG_UNPARSED_ARGUMENTS}
)
endfunction()
#=======================================================================
# Set/get whether the specified content has been populated yet.
# The setter also records the source and binary dirs used.
#=======================================================================
# Internal use, projects must not call this directly. It is
# intended for use by the FetchContent_Populate() function to
# record when FetchContent_Populate() is called for a particular
# content name.
function(__FetchContent_setPopulated contentName sourceDir binaryDir)
string(TOLOWER ${contentName} contentNameLower)
set(prefix "_FetchContent_${contentNameLower}")
set(propertyName "${prefix}_sourceDir")
define_property(GLOBAL PROPERTY ${propertyName}
BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()"
FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}"
)
set_property(GLOBAL PROPERTY ${propertyName} ${sourceDir})
set(propertyName "${prefix}_binaryDir")
define_property(GLOBAL PROPERTY ${propertyName}
BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()"
FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}"
)
set_property(GLOBAL PROPERTY ${propertyName} ${binaryDir})
set(propertyName "${prefix}_populated")
define_property(GLOBAL PROPERTY ${propertyName}
BRIEF_DOCS "Internal implementation detail of FetchContent_Populate()"
FULL_DOCS "Details used by FetchContent_Populate() for ${contentName}"
)
set_property(GLOBAL PROPERTY ${propertyName} True)
endfunction()
# Set variables in the calling scope for any of the retrievable
# properties. If no specific properties are requested, variables
# will be set for all retrievable properties.
#
# This function is intended to also be used by projects as the canonical
# way to detect whether they should call FetchContent_Populate()
# and pull the populated source into the build with add_subdirectory(),
# if they are using the populated content in that way.
function(FetchContent_GetProperties contentName)
string(TOLOWER ${contentName} contentNameLower)
set(options "")
set(oneValueArgs SOURCE_DIR BINARY_DIR POPULATED)
set(multiValueArgs "")
cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(NOT ARG_SOURCE_DIR AND
NOT ARG_BINARY_DIR AND
NOT ARG_POPULATED)
# No specific properties requested, provide them all
set(ARG_SOURCE_DIR ${contentNameLower}_SOURCE_DIR)
set(ARG_BINARY_DIR ${contentNameLower}_BINARY_DIR)
set(ARG_POPULATED ${contentNameLower}_POPULATED)
endif()
set(prefix "_FetchContent_${contentNameLower}")
if(ARG_SOURCE_DIR)
set(propertyName "${prefix}_sourceDir")
get_property(value GLOBAL PROPERTY ${propertyName})
if(value)
set(${ARG_SOURCE_DIR} ${value} PARENT_SCOPE)
endif()
endif()
if(ARG_BINARY_DIR)
set(propertyName "${prefix}_binaryDir")
get_property(value GLOBAL PROPERTY ${propertyName})
if(value)
set(${ARG_BINARY_DIR} ${value} PARENT_SCOPE)
endif()
endif()
if(ARG_POPULATED)
set(propertyName "${prefix}_populated")
get_property(value GLOBAL PROPERTY ${propertyName} DEFINED)
set(${ARG_POPULATED} ${value} PARENT_SCOPE)
endif()
endfunction()
#=======================================================================
# Performing the population
#=======================================================================
# The value of contentName will always have been lowercased by the caller.
# All other arguments are assumed to be options that are understood by
# ExternalProject_Add(), except for QUIET and SUBBUILD_DIR.
function(__FetchContent_directPopulate contentName)
set(options
QUIET
)
set(oneValueArgs
SUBBUILD_DIR
SOURCE_DIR
BINARY_DIR
# Prevent the following from being passed through
CONFIGURE_COMMAND
BUILD_COMMAND
INSTALL_COMMAND
TEST_COMMAND
)
set(multiValueArgs "")
cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(NOT ARG_SUBBUILD_DIR)
message(FATAL_ERROR "Internal error: SUBBUILD_DIR not set")
elseif(NOT IS_ABSOLUTE "${ARG_SUBBUILD_DIR}")
set(ARG_SUBBUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/${ARG_SUBBUILD_DIR}")
endif()
if(NOT ARG_SOURCE_DIR)
message(FATAL_ERROR "Internal error: SOURCE_DIR not set")
elseif(NOT IS_ABSOLUTE "${ARG_SOURCE_DIR}")
set(ARG_SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/${ARG_SOURCE_DIR}")
endif()
if(NOT ARG_BINARY_DIR)
message(FATAL_ERROR "Internal error: BINARY_DIR not set")
elseif(NOT IS_ABSOLUTE "${ARG_BINARY_DIR}")
set(ARG_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/${ARG_BINARY_DIR}")
endif()
# Ensure the caller can know where to find the source and build directories
# with some convenient variables. Doing this here ensures the caller sees
# the correct result in the case where the default values are overridden by
# the content details set by the project.
set(${contentName}_SOURCE_DIR "${ARG_SOURCE_DIR}" PARENT_SCOPE)
set(${contentName}_BINARY_DIR "${ARG_BINARY_DIR}" PARENT_SCOPE)
# The unparsed arguments may contain spaces, so build up ARG_EXTRA
# in such a way that it correctly substitutes into the generated
# CMakeLists.txt file with each argument quoted.
unset(ARG_EXTRA)
foreach(arg IN LISTS ARG_UNPARSED_ARGUMENTS)
set(ARG_EXTRA "${ARG_EXTRA} \"${arg}\"")
endforeach()
# Hide output if requested, but save it to a variable in case there's an
# error so we can show the output upon failure. When not quiet, don't
# capture the output to a variable because the user may want to see the
# output as it happens (e.g. progress during long downloads). Combine both
# stdout and stderr in the one capture variable so the output stays in order.
if (ARG_QUIET)
set(outputOptions
OUTPUT_VARIABLE capturedOutput
ERROR_VARIABLE capturedOutput
)
else()
set(capturedOutput)
set(outputOptions)
message(STATUS "Populating ${contentName}")
endif()
if(CMAKE_GENERATOR)
set(generatorOpts "-G${CMAKE_GENERATOR}")
if(CMAKE_GENERATOR_PLATFORM)
list(APPEND generatorOpts "-A${CMAKE_GENERATOR_PLATFORM}")
endif()
if(CMAKE_GENERATOR_TOOLSET)
list(APPEND generatorOpts "-T${CMAKE_GENERATOR_TOOLSET}")
endif()
if(CMAKE_MAKE_PROGRAM)
list(APPEND generatorOpts "-DCMAKE_MAKE_PROGRAM:FILEPATH=${CMAKE_MAKE_PROGRAM}")
endif()
else()
# Likely we've been invoked via CMake's script mode where no
# generator is set (and hence CMAKE_MAKE_PROGRAM could not be
# trusted even if provided). We will have to rely on being
# able to find the default generator and build tool.
unset(generatorOpts)
endif()
# Create and build a separate CMake project to carry out the population.
# If we've already previously done these steps, they will not cause
# anything to be updated, so extra rebuilds of the project won't occur.
# Make sure to pass through CMAKE_MAKE_PROGRAM in case the main project
# has this set to something not findable on the PATH.
configure_file("${__FetchContent_privateDir}/CMakeLists.cmake.in"
"${ARG_SUBBUILD_DIR}/CMakeLists.txt")
execute_process(
COMMAND ${CMAKE_COMMAND} ${generatorOpts} .
RESULT_VARIABLE result
${outputOptions}
WORKING_DIRECTORY "${ARG_SUBBUILD_DIR}"
)
if(result)
if(capturedOutput)
message("${capturedOutput}")
endif()
message(FATAL_ERROR "CMake step for ${contentName} failed: ${result}")
endif()
execute_process(
COMMAND ${CMAKE_COMMAND} --build .
RESULT_VARIABLE result
${outputOptions}
WORKING_DIRECTORY "${ARG_SUBBUILD_DIR}"
)
if(result)
if(capturedOutput)
message("${capturedOutput}")
endif()
message(FATAL_ERROR "Build step for ${contentName} failed: ${result}")
endif()
endfunction()
option(FETCHCONTENT_FULLY_DISCONNECTED "Disables all attempts to download or update content and assumes source dirs already exist")
option(FETCHCONTENT_UPDATES_DISCONNECTED "Enables UPDATE_DISCONNECTED behavior for all content population")
option(FETCHCONTENT_QUIET "Enables QUIET option for all content population" ON)
set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/_deps" CACHE PATH "Directory under which to collect all populated content")
# Populate the specified content using details stored from
# an earlier call to FetchContent_Declare().
function(FetchContent_Populate contentName)
if(NOT contentName)
message(FATAL_ERROR "Empty contentName not allowed for FetchContent_Populate()")
endif()
string(TOLOWER ${contentName} contentNameLower)
if(ARGN)
# This is the direct population form with details fully specified
# as part of the call, so we already have everything we need
__FetchContent_directPopulate(
${contentNameLower}
SUBBUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/${contentNameLower}-subbuild"
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/${contentNameLower}-src"
BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/${contentNameLower}-build"
${ARGN} # Could override any of the above ..._DIR variables
)
# Pass source and binary dir variables back to the caller
set(${contentNameLower}_SOURCE_DIR "${${contentNameLower}_SOURCE_DIR}" PARENT_SCOPE)
set(${contentNameLower}_BINARY_DIR "${${contentNameLower}_BINARY_DIR}" PARENT_SCOPE)
# Don't set global properties, or record that we did this population, since
# this was a direct call outside of the normal declared details form.
# We only want to save values in the global properties for content that
# honours the hierarchical details mechanism so that projects are not
# robbed of the ability to override details set in nested projects.
return()
endif()
# No details provided, so assume they were saved from an earlier call
# to FetchContent_Declare(). Do a check that we haven't already
# populated this content before in case the caller forgot to check.
FetchContent_GetProperties(${contentName})
if(${contentNameLower}_POPULATED)
message(FATAL_ERROR "Content ${contentName} already populated in ${${contentNameLower}_SOURCE_DIR}")
endif()
string(TOUPPER ${contentName} contentNameUpper)
set(FETCHCONTENT_SOURCE_DIR_${contentNameUpper}
"${FETCHCONTENT_SOURCE_DIR_${contentNameUpper}}"
CACHE PATH "When not empty, overrides where to find pre-populated content for ${contentName}")
if(FETCHCONTENT_SOURCE_DIR_${contentNameUpper})
# The source directory has been explicitly provided in the cache,
# so no population is required
set(${contentNameLower}_SOURCE_DIR "${FETCHCONTENT_SOURCE_DIR_${contentNameUpper}}")
set(${contentNameLower}_BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build")
elseif(FETCHCONTENT_FULLY_DISCONNECTED)
# Bypass population and assume source is already there from a previous run
set(${contentNameLower}_SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-src")
set(${contentNameLower}_BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build")
else()
# Support both a global "disconnect all updates" and a per-content
# update test (either one being set disables updates for this content).
option(FETCHCONTENT_UPDATES_DISCONNECTED_${contentNameUpper}
"Enables UPDATE_DISCONNECTED behavior just for population of ${contentName}")
if(FETCHCONTENT_UPDATES_DISCONNECTED OR
FETCHCONTENT_UPDATES_DISCONNECTED_${contentNameUpper})
set(disconnectUpdates True)
else()
set(disconnectUpdates False)
endif()
if(FETCHCONTENT_QUIET)
set(quietFlag QUIET)
else()
unset(quietFlag)
endif()
__FetchContent_getSavedDetails(${contentName} contentDetails)
if("${contentDetails}" STREQUAL "")
message(FATAL_ERROR "No details have been set for content: ${contentName}")
endif()
__FetchContent_directPopulate(
${contentNameLower}
${quietFlag}
UPDATE_DISCONNECTED ${disconnectUpdates}
SUBBUILD_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-subbuild"
SOURCE_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-src"
BINARY_DIR "${FETCHCONTENT_BASE_DIR}/${contentNameLower}-build"
# Put the saved details last so they can override any of the
# options we set above (this can include SOURCE_DIR or
# BINARY_DIR)
${contentDetails}
)
endif()
__FetchContent_setPopulated(
${contentName}
${${contentNameLower}_SOURCE_DIR}
${${contentNameLower}_BINARY_DIR}
)
# Pass variables back to the caller. The variables passed back here
# must match what FetchContent_GetProperties() sets when it is called
# with just the content name.
set(${contentNameLower}_SOURCE_DIR "${${contentNameLower}_SOURCE_DIR}" PARENT_SCOPE)
set(${contentNameLower}_BINARY_DIR "${${contentNameLower}_BINARY_DIR}" PARENT_SCOPE)
set(${contentNameLower}_POPULATED True PARENT_SCOPE)
endfunction()
# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
# file Copyright.txt or https://cmake.org/licensing for details.
cmake_minimum_required(VERSION ${CMAKE_VERSION})
# We name the project and the target for the ExternalProject_Add() call
# to something that will highlight to the user what we are working on if
# something goes wrong and an error message is produced.
project(${contentName}-populate NONE)
include(ExternalProject)
ExternalProject_Add(${contentName}-populate
${ARG_EXTRA}
SOURCE_DIR "${ARG_SOURCE_DIR}"
BINARY_DIR "${ARG_BINARY_DIR}"
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
# Try to find Valgrind headers and libraries.
#
# Usage of this module as follows:
# find_package(Valgrind)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# VALGRIND_ROOT Set this variable to the root installation of valgrind if the
# module has problems finding the proper installation path.
#
# Variables defined by this module:
# Valgrind_FOUND System has valgrind
# Valgrind_INCLUDE_DIR where to find valgrind/memcheck.h, etc.
# Valgrind_EXECUTABLE the valgrind executable.
# Get hint from environment variable (if any)
if(NOT VALGRIND_ROOT AND DEFINED ENV{VALGRIND_ROOT})
set(VALGRIND_ROOT "$ENV{VALGRIND_ROOT}" CACHE PATH "Valgrind base directory location (optional, used for nonstandard installation paths)")
mark_as_advanced(VALGRIND_ROOT)
endif()
# Search path for nonstandard locations
if(VALGRIND_ROOT)
set(Valgrind_INCLUDE_PATH PATHS "${VALGRIND_ROOT}/include" NO_DEFAULT_PATH)
set(Valgrind_BINARY_PATH PATHS "${VALGRIND_ROOT}/bin" NO_DEFAULT_PATH)
endif()
find_path(Valgrind_INCLUDE_DIR valgrind HINTS ${Valgrind_INCLUDE_PATH})
find_program(Valgrind_EXECUTABLE NAMES valgrind HINTS ${Valgrind_BINARY_PATH})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Valgrind DEFAULT_MSG Valgrind_INCLUDE_DIR Valgrind_EXECUTABLE)
mark_as_advanced(Valgrind_INCLUDE_DIR Valgrind_EXECUTABLE)
if(NOT Valgrind_FOUND)
if(Valgrind_FIND_REQUIRED)
message(FATAL_ERROR "Valgrind required but it seems it has not be installed.")
endif()
else()
message(STATUS "Found Valgrind: ${Valgrind_EXECUTABLE}")
endif()
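A minimal usage sketch for the FindValgrind module above (the test name and the
module's location on CMAKE_MODULE_PATH are assumptions, not something this commit sets up):
# Hypothetical consumer of FindValgrind.cmake
list(APPEND CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake/Modules") # wherever FindValgrind.cmake lives
find_package(Valgrind)
if(Valgrind_FOUND)
  # Run an existing test executable under valgrind's memcheck tool.
  add_test(NAME my_test_memcheck
    COMMAND ${Valgrind_EXECUTABLE} --tool=memcheck --leak-check=full
            $<TARGET_FILE:my_test>)
endif()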
## FetchContent
`FetchContent.cmake` and `FetchContent/CMakeLists.cmake.in`
are copied from `cmake/3.11.0/share/cmake-3.11/Modules`.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_googletest)
if(CMAKE_VERSION VERSION_LESS 3.11)
# FetchContent is available since 3.11,
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
# so that it can be used in lower CMake versions.
message(STATUS "Use FetchContent provided by k2")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(googletest_URL "https://github.com/google/googletest/archive/release-1.10.0.tar.gz")
set(googletest_HASH "SHA256=9dc9157a9a1551ec7a7e43daea9a694a0bb5fb8bec81235d8a1e6ef64c716dcb")
set(BUILD_GMOCK ON CACHE BOOL "" FORCE)
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
set(gtest_disable_pthreads ON CACHE BOOL "" FORCE)
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_Declare(googletest
URL ${googletest_URL}
URL_HASH ${googletest_HASH}
)
FetchContent_GetProperties(googletest)
if(NOT googletest_POPULATED)
message(STATUS "Downloading googletest")
FetchContent_Populate(googletest)
endif()
message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}")
message(STATUS "googletest's binary dir is ${googletest_BINARY_DIR}")
if(APPLE)
set(CMAKE_MACOSX_RPATH ON) # to solve the following warning on macOS
endif()
#[==[
-- Generating done
Policy CMP0042 is not set: MACOSX_RPATH is enabled by default. Run "cmake
--help-policy CMP0042" for policy details. Use the cmake_policy command to
set the policy and suppress this warning.
MACOSX_RPATH is not specified for the following targets:
gmock
gmock_main
gtest
gtest_main
This warning is for project developers. Use -Wno-dev to suppress it.
]==]
add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
target_include_directories(gtest
INTERFACE
${googletest_SOURCE_DIR}/googletest/include
${googletest_SOURCE_DIR}/googlemock/include
)
endfunction()
download_googletest()
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_pybind11)
if(CMAKE_VERSION VERSION_LESS 3.11)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz")
set(pybind11_HASH "SHA256=90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571")
set(double_quotes "\"")
set(dollar "\$")
set(semicolon "\;")
if(NOT WIN32)
FetchContent_Declare(pybind11
URL ${pybind11_URL}
URL_HASH ${pybind11_HASH}
)
else()
FetchContent_Declare(pybind11
URL ${pybind11_URL}
URL_HASH ${pybind11_HASH}
)
endif()
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
message(STATUS "Downloading pybind11")
FetchContent_Populate(pybind11)
endif()
message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}")
add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL)
endfunction()
download_pybind11()
#
# This file is copied from
# https://github.com/pytorch/pytorch/blob/master/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
#
#
# Synopsis:
# CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
# -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
# target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
# - "Auto" detects local machine GPU compute arch at runtime.
# - "Common" and "All" cover common and entire subsets of architectures
# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
# NAME: Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere
# NUM: Any number. Only those pairs are currently accepted by NVCC though:
# 3.5 3.7 5.0 5.2 5.3 6.0 6.2 7.0 7.2 7.5 8.0
# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
# Additionally, sets ${out_variable}_readable to the resulting numeric list
# Example:
# CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
# LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
#
# More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
#
if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language
if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA"
AND CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)")
set(CUDA_VERSION "${CMAKE_MATCH_1}")
endif()
endif()
# See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list
# This list will be used for CUDA_ARCH_NAME = All option
set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell")
# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0")
if(CUDA_VERSION VERSION_LESS "7.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "5.2")
endif()
# This list is used to filter CUDA archs when autodetecting
set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0")
if(CUDA_VERSION VERSION_GREATER "6.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2")
if(CUDA_VERSION VERSION_LESS "8.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "6.0")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "7.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "6.0" "6.1" "6.2")
if(CUDA_VERSION VERSION_LESS "9.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "7.0")
endif()
endif ()
if(CUDA_VERSION VERSION_GREATER "8.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Volta")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.0" "7.2")
if(CUDA_VERSION VERSION_LESS "10.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "9.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Turing")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.5")
if(CUDA_VERSION VERSION_LESS "11.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5+PTX")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "10.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0")
if(CUDA_VERSION VERSION_LESS "11.1")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.6")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "11.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6" "8.6+PTX")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6")
if(CUDA_VERSION VERSION_LESS "12.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
endif()
endif()
################################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
# CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
#
function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
if(NOT CUDA_GPU_DETECT_OUTPUT)
if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language
set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu")
else()
set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp")
endif()
file(WRITE ${file} ""
"#include <cuda_runtime.h>\n"
"#include <cstdio>\n"
"int main()\n"
"{\n"
" int count = 0;\n"
" if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
" if (count == 0) return -1;\n"
" for (int device = 0; device < count; ++device)\n"
" {\n"
" cudaDeviceProp prop;\n"
" if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
" std::printf(\"%d.%d \", prop.major, prop.minor);\n"
" }\n"
" return 0;\n"
"}\n")
if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language
try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
RUN_OUTPUT_VARIABLE compute_capabilities)
else()
try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
LINK_LIBRARIES ${CUDA_LIBRARIES}
RUN_OUTPUT_VARIABLE compute_capabilities)
endif()
# Filter unrelated content out of the output.
string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}")
if(run_result EQUAL 0)
string(REPLACE "2.1" "2.1(2.0)" compute_capabilities "${compute_capabilities}")
set(CUDA_GPU_DETECT_OUTPUT ${compute_capabilities}
CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
endif()
endif()
if(NOT CUDA_GPU_DETECT_OUTPUT)
message(STATUS "Automatic GPU detection failed. Building for common architectures.")
set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
else()
# Filter based on CUDA version supported archs
set(CUDA_GPU_DETECT_OUTPUT_FILTERED "")
separate_arguments(CUDA_GPU_DETECT_OUTPUT)
foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT})
if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR
ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE))
list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM)
string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}")
else()
string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}")
endif()
endforeach()
set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE)
endif()
endfunction()
################################################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list
# Usage:
# SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs])
function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
set(CUDA_ARCH_LIST "${ARGN}")
if("X${CUDA_ARCH_LIST}" STREQUAL "X" )
set(CUDA_ARCH_LIST "Auto")
endif()
set(cuda_arch_bin)
set(cuda_arch_ptx)
if("${CUDA_ARCH_LIST}" STREQUAL "All")
set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES})
elseif("${CUDA_ARCH_LIST}" STREQUAL "Common")
set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES})
elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto")
CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST)
message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}")
endif()
# Now process the list and look for names
string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
foreach(arch_name ${CUDA_ARCH_LIST})
set(arch_bin)
set(arch_ptx)
set(add_ptx FALSE)
# Check to see if we are compiling PTX
if(arch_name MATCHES "(.*)\\+PTX$")
set(add_ptx TRUE)
set(arch_name ${CMAKE_MATCH_1})
endif()
if(arch_name MATCHES "^([0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$")
set(arch_bin ${CMAKE_MATCH_1})
set(arch_ptx ${arch_bin})
else()
# Look for it in our list of known architectures
if(${arch_name} STREQUAL "Kepler+Tesla")
set(arch_bin 3.7)
elseif(${arch_name} STREQUAL "Kepler")
set(arch_bin 3.5)
set(arch_ptx 3.5)
elseif(${arch_name} STREQUAL "Maxwell+Tegra")
set(arch_bin 5.3)
elseif(${arch_name} STREQUAL "Maxwell")
set(arch_bin 5.0 5.2)
set(arch_ptx 5.2)
elseif(${arch_name} STREQUAL "Pascal")
set(arch_bin 6.0 6.1)
set(arch_ptx 6.1)
elseif(${arch_name} STREQUAL "Volta")
set(arch_bin 7.0 7.0)
set(arch_ptx 7.0)
elseif(${arch_name} STREQUAL "Turing")
set(arch_bin 7.5)
set(arch_ptx 7.5)
elseif(${arch_name} STREQUAL "Ampere")
set(arch_bin 8.0)
set(arch_ptx 8.0)
else()
message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
endif()
endif()
if(NOT arch_bin)
message(SEND_ERROR "arch_bin wasn't set for some reason")
endif()
list(APPEND cuda_arch_bin ${arch_bin})
if(add_ptx)
if (NOT arch_ptx)
set(arch_ptx ${arch_bin})
endif()
list(APPEND cuda_arch_ptx ${arch_ptx})
endif()
endforeach()
# remove dots and convert to lists
string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}")
string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
if(cuda_arch_bin)
list(REMOVE_DUPLICATES cuda_arch_bin)
endif()
if(cuda_arch_ptx)
list(REMOVE_DUPLICATES cuda_arch_ptx)
endif()
set(nvcc_flags "")
set(nvcc_archs_readable "")
# Tell NVCC to add binaries for the specified GPUs
foreach(arch ${cuda_arch_bin})
if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
# User explicitly specified ARCH for the concrete CODE
list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
else()
# User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
list(APPEND nvcc_archs_readable sm_${arch})
endif()
endforeach()
# Tell NVCC to add PTX intermediate code for the specified architectures
foreach(arch ${cuda_arch_ptx})
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
list(APPEND nvcc_archs_readable compute_${arch})
endforeach()
string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()
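# A minimal usage sketch (not part of the original file; the variable name
# EXAMPLE_ARCH_FLAGS is illustrative only). Passing no arch list autodetects
# installed GPUs ("Auto"); an explicit list such as "6.1 7.0+PTX" requests
# SASS for 6.1 and 7.0 plus PTX for 7.0:
#
#   cuda_select_nvcc_arch_flags(EXAMPLE_ARCH_FLAGS 6.1 7.0+PTX)
#   message(STATUS "nvcc flags: ${EXAMPLE_ARCH_FLAGS}")
#   message(STATUS "readable:   ${EXAMPLE_ARCH_FLAGS_readable}")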
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# PYTHON_EXECUTABLE is set by pybind11.cmake
message(STATUS "Python executable: ${PYTHON_EXECUTABLE}")
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import os; import torch; print(os.path.dirname(torch.__file__))"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_DIR
)
list(APPEND CMAKE_PREFIX_PATH "${TORCH_DIR}")
find_package(Torch REQUIRED)
# set the global CMAKE_CXX_FLAGS so that
# fast_rnnt uses the same ABI flag as PyTorch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
if(FT_WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${TORCH_CXX_FLAGS}")
endif()
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE OT_TORCH_VERSION_MAJOR
)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE OT_TORCH_VERSION_MINOR
)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__)"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_VERSION
)
message(STATUS "PyTorch version: ${TORCH_VERSION}")
if(FT_WITH_CUDA)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.version.cuda)"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_CUDA_VERSION
)
message(STATUS "PyTorch cuda version: ${TORCH_CUDA_VERSION}")
if(NOT CUDA_VERSION VERSION_EQUAL TORCH_CUDA_VERSION)
message(FATAL_ERROR
"PyTorch ${TORCH_VERSION} is compiled with CUDA ${TORCH_CUDA_VERSION}.\n"
"But you are using CUDA ${CUDA_VERSION} to compile optimized_transducer.\n"
"Please try to use the same CUDA version for PyTorch and optimized_transducer.\n"
"**You can remove this check if you are sure this will not cause "
"problems**\n"
)
endif()
# Solve the following error for NVCC:
# unknown option `-Wall`
#
# It contains only some -Wno-* flags, so it is OK
# to set them to empty
set_property(TARGET torch_cuda
PROPERTY
INTERFACE_COMPILE_OPTIONS ""
)
set_property(TARGET torch_cpu
PROPERTY
INTERFACE_COMPILE_OPTIONS ""
)
endif()
add_subdirectory(csrc)
add_subdirectory(python)
include_directories(${CMAKE_SOURCE_DIR})
set(srcs
mutual_information_cpu.cc
)
add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
if(FT_WITH_CUDA)
set(cuda_srcs mutual_information_cuda.cu)
add_library(mutual_information_core_cuda ${cuda_srcs})
target_link_libraries(mutual_information_core_cuda PUBLIC ${TORCH_LIBRARIES})
target_include_directories(mutual_information_core_cuda PUBLIC ${PYTHON_INCLUDE_DIRS})
target_link_libraries(mutual_information_core PUBLIC mutual_information_core_cuda)
endif()
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#include <torch/extension.h>
#include <cmath>
#include <vector>
#ifdef __CUDA_ARCH__
#define FT_CUDA_HOSTDEV __host__ __device__
#else
#define FT_CUDA_HOSTDEV
#endif
namespace fast_rnnt {
FT_CUDA_HOSTDEV inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
// returns log(exp(x) + exp(y)).
FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
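// A small numeric sketch (added for illustration, not in the original header)
// of how LogAdd behaves; values are approximate:
//   LogAdd(log(0.25), log(0.75)) == log(1.0) == 0.0 (up to rounding);
//   LogAdd(-INFINITY, -INFINITY) == -INFINITY (the "diff - diff != 0" branch),
//   i.e. adding probabilities of zero stays at probability zero.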
/*
Forward of mutual_information. See also comment of `mutual_information`
in ../python/fast_rnnt/mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.
@param px Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if
modified. `modified` can be worked out from this. In not-modified case,
it can be thought of as the log-odds ratio of generating the next x in
the sequence, i.e.
xy[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating x_s.
(See mutual_information.py for more info).
@param py The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
@param p This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of length s and t respectively, from the
b'th sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
if not modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
... treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t, which
contains, where for each batch element b, boundary[b]
equals [s_begin, t_begin, s_end, t_end]
which are the beginning and end (i.e. one-past-the-last)
of the x and y sequences that we should process.
Alternatively, may be a tensor of shape [0][0] and type
int64_t; the elements will default to (0, 0, S, T).
@return A tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directy to this function).
The block-dim and grid-dim must both be 1-dimensional, and the block-dim must
be at least 128.
*/
torch::Tensor MutualInformationCpu(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
torch::Tensor MutualInformationCuda(
torch::Tensor px, // [B][S][T+1] if !modified, [B][S][T] if modified.
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
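// A minimal usage sketch (not part of the original header), assuming float
// CPU tensors; the shapes follow the comments above, and passing an empty
// optional makes `boundary` default to (0, 0, S, T):
//
//   int B = 2, S = 3, T = 4;
//   torch::Tensor px = torch::randn({B, S, T + 1}),
//                 py = torch::randn({B, S + 1, T}),
//                 p = torch::empty({B, S + 1, T + 1});
//   torch::Tensor mi = fast_rnnt::MutualInformationCpu(px, py, torch::nullopt, p);
//   // mi has shape [B]; p now holds the full recursion and can be passed,
//   // together with a gradient w.r.t. mi, to MutualInformationBackwardCpu.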
/*
Backward of mutual_information; returns (grad_px, grad_py).
If overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad);
std::vector<torch::Tensor> MutualInformationBackwardCuda(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/mutual_information.h"
namespace fast_rnnt {
// Forward of mutual_information. See the """...""" comment of
// `mutual_information_recursion`
// in k2/python/k2/mutual_information.py for documentation of the
// behavior of this function.
// px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
// `modified` from this.
// py: of shape [B, S+1, T]
// boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end)
// defaulting to (0, 0, S, T).
// p: of shape [B, S+1, T+1]
// Computes the recursion:
// if !modified:
// p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
// p[b,s,t-1] + py[b,s,t-1])
// if modified:
// p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
// p[b,s,t-1] + py[b,s,t-1])
// .. treating out-of-range elements as -infinity and with special cases:
// p[b, s_begin, t_begin] = 0.0
//
// and this function returns a tensor of shape (B,) consisting of elements
// p[b, s_end, t_end]
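// A tiny worked example (added for illustration, not in the original source),
// with B = 1, S = 1, T = 1, the !modified recursion and boundary = (0, 0, 1, 1):
//   p[0,0,0] = 0
//   p[0,1,0] = p[0,0,0] + px[0,0,0]   (only the x-direction exists at t = 0)
//   p[0,0,1] = p[0,0,0] + py[0,0,0]
//   p[0,1,1] = log_add(p[0,0,1] + px[0,0,1],
//                      p[0,1,0] + py[0,1,0])
// and the returned value is ans[0] = p[0,1,1].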
torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu(),
"inputs must be CPU tensors");
bool modified = (px.size(2) == py.size(2));
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(),
py_a = py.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>();
auto boundary_a = boundary.accessor<int64_t, 2>();
auto ans_a = ans.accessor<scalar_t, 1>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
p_a[b][s_begin][t_begin] = 0.0;
if (modified) {
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] = -std::numeric_limits<scalar_t>::infinity();
} else {
// note: t_offset = 0 so don't need t_begin + t_offset below.
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] =
p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
}
for (int t = t_begin + 1; t <= t_end; ++t)
p_a[b][s_begin][t] =
p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
for (int s = s_begin + 1; s <= s_end; ++s) {
scalar_t p_s_t1 = p_a[b][s][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t) {
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s -1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
p_s_t1 + py_a[b][s][t - 1]);
}
}
ans_a[b] = p_a[b][s_end][t_end];
}
}));
return ans;
}
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p, torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
bool modified = (px.size(2) == py.size(2));
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu() && ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
bool has_boundary = opt_boundary.has_value();
int T1 = T + (modified ? 0 : 1);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
: torch::empty({B, S, T1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
: torch::empty({B, S + 1, T}, opts));
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>(),
p_grad_a = p_grad.accessor<scalar_t, 3>(),
px_grad_a = px_grad.accessor<scalar_t, 3>(),
py_grad_a = py_grad.accessor<scalar_t, 3>();
auto ans_grad_a = ans_grad.accessor<scalar_t, 1>();
auto boundary_a = boundary.accessor<int64_t, 2>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s > s_begin; --s) {
for (int t = t_end; t > t_begin; --t) {
// The statement we are backpropagating here is:
//   p_a[b][s][t] = LogAdd(
//       p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
//       p_a[b][s][t - 1] + py_a[b][s][t - 1]);
scalar_t term1 = p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t];
if (total - total != 0)
total = 0;
scalar_t term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t];
scalar_t term1_grad, term2_grad;
if (term1_deriv - term1_deriv == 0.0) {
term1_grad = term1_deriv * grad;
term2_grad = term2_deriv * grad;
} else {
// could happen if total == -inf
term1_grad = term2_grad = 0.0;
}
px_grad_a[b][s - 1][t + t_offset] = term1_grad;
p_grad_a[b][s - 1][t + t_offset] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
for (int t = t_end; t > t_begin; --t) {
// Backprop for:
// p_a[b][s_begin][t] =
// p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t this_p_grad = p_grad_a[b][s_begin][t];
p_grad_a[b][s_begin][t - 1] += this_p_grad;
py_grad_a[b][s_begin][t - 1] = this_p_grad;
}
if (!modified) {
for (int s = s_end; s > s_begin; --s) {
// Backprop for:
// p_a[b][s][t_begin] =
// p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
p_grad_a[b][s - 1][t_begin] += this_p_grad;
px_grad_a[b][s - 1][t_begin] = this_p_grad;
}
} // else these were all -infinity's and there is nothing to
// backprop.
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use this for a check, that the grad at the beginning
// of the sequence is equal to the grad at the end of the sequence.
if (ans_grad_a[b] != 0.0) {
float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
if (fabs(grad_ratio - 1.0) > 0.01) {
// K2_LOG(WARNING)
//<< "Warning: mutual_information backprop: expected these "
//<< "numbers to be the same:"
//<< static_cast<float>(p_grad_a[b][s_begin][t_begin]) << " vs "
//<< static_cast<float>(ans_grad_a[b]);
}
}
}
}));
return std::vector<torch::Tensor>({px_grad, py_grad});
}
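// A possible gradient-check sketch (illustrative comment only, not in the
// original file): run the forward pass, then the backward pass with
// ans_grad = ones, and compare the returned gradients against finite
// differences of the forward output, e.g.
//
//   torch::Tensor ans = MutualInformationCpu(px, py, boundary, p);
//   std::vector<torch::Tensor> grads =
//       MutualInformationBackwardCpu(px, py, boundary, p, torch::ones({B}));
//   // grads[0] is px_grad, grads[1] is py_grad.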
} // namespace fast_rnnt
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <c10/cuda/CUDAStream.h> // for getCurrentCUDAStream()
#include <cooperative_groups.h>
#include "fast_rnnt/csrc/mutual_information.h"
namespace fast_rnnt {
/*
  Forward of mutual_information. Each thread block computes blocks of the 'p'
  array of (s, t) shape equal to (BLOCK_SIZE, BLOCK_SIZE), e.g. (32, 32).
@@ -55,13 +40,14 @@
      is because we assume BLOCK_SIZE + 1 <= 64 in some data-loading
      code).
  Args:
      px: Tensor of shape [B][S][T + 1], if !modified; [B][S][T] if modified;
          may be interpreted as the log-odds ratio of
          generating the next x in the sequence, i.e.
          xy[b][s][t] is the log of
            p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
          i.e. the log-prob of generating x_s given subsequences of lengths
          (s, t), divided by the prior probability of generating x_s. (See
          mutual_information.py for more info).
      py: The log-odds ratio of generating the next y in the sequence.
          Shape [B][S + 1][T]
      p: This function writes to p[b][s][t] the mutual information between
@@ -71,10 +57,14 @@
         in the case where s_begin == t_begin == 0:
             p[b,0,0] = 0.0
             if not `modified`:
                p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
                                   p[b,s,t-1] + py[b,s,t-1])      (eq. 0)
             if `modified`:
                p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
                                   p[b,s,t-1] + py[b,s,t-1])      (eq. 0)
             treating values with any -1 index as -infinity.
         .. if `boundary` is set, we start from p[b,s_begin,t_begin]=0.0.
      boundary: If set, a tensor of shape [B][4] of type int64_t, which
         contains, where for each batch element b, boundary[b] equals
@@ -95,29 +85,32 @@
      be at least 128.
*/
template <typename scalar_t,
          int BLOCK_SIZE> // e.g. BLOCK_SIZE == 16 or 32.
__global__ void mutual_information_kernel(
    // B, S, T + 1, i.e. batch, x_seq_length, y_seq_length + 1
    torch::PackedTensorAccessor32<scalar_t, 3> px,
    torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
    // B, S + 1, T + 1. This is an output.
    torch::PackedTensorAccessor32<scalar_t, 3> p,
    // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
    torch::PackedTensorAccessor32<int64_t, 2> boundary,
    torch::PackedTensorAccessor32<scalar_t, 1> ans, // [B]
    int iter) { // This kernel is sequentially called with 'iter' = 0, 1, 2 and
                // so on, up to num_iters - 1 where
                //   num_iters = num_s_blocks + num_t_blocks - 1
                //   num_s_blocks = S / BLOCK_SIZE + 1
                //   num_t_blocks = T / BLOCK_SIZE + 1
                // so that each group depends on the previous group...
const int B = px.size(0), S = px.size(1), T = py.size(2);
const bool modified = (px.size(2) == T);
const int t_offset = (modified ? -1 : 0); // see CPU code to understand.
// num_s_blocks and num_t_blocks are the number of blocks we need to cover the
// array of size (S, T) with blocks of this size, in the s and t directions
// respectively.
// You can read the following expressions as simplifications of, for example,
// num_s_blocks = ((S + 1) + BLOCK_SIZE - 1) / BLOCK_SIZE,
// i.e. rounding-up division of (S + 1) by BLOCK_SIZE, and the same for
// (T + 1).
const int num_s_blocks = S / BLOCK_SIZE + 1;
//, num_t_blocks = T / BLOCK_SIZE + 1;
@@ -134,16 +127,16 @@
int num_blocks_this_iter = min(iter + 1, num_s_blocks);
// For the block with s_block_begin == 0 and t_block_begin == 0 (for
// easy illustration), px_buf[s][t] will contain px[s - 1][t + t_offset]; or
// -infinity for out-of-range indexes into px. Likewise, py_buf[s][t] will
// contain py[s][t - 1].
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
    py_buf[BLOCK_SIZE][BLOCK_SIZE];
// p_buf[s][t] == p[s+s_block_begin-1][t+t_block_begin-1]
// 1st row/col of p_buf correspond to the previously computed blocks (lower
// `iter`), or to negative indexes into p. So, for the origin block,
// p_buf[s][t] corresponds to p[s - 1][t - 1]; or -inf for
// out-of-range values.
__shared__ scalar_t p_buf[BLOCK_SIZE + 1][BLOCK_SIZE + 1];
@@ -165,7 +158,7 @@
       batch_block_iter < B * num_blocks_this_iter;
       batch_block_iter += gridDim.x) {
    int block = batch_block_iter / B,
        b = batch_block_iter % B; // b is the index into the batch
// Note: `block` can be no greater than `iter` because num_blocks_this_iter
// <= iter + 1, i.e. iter >= num_blocks_this_iter - 1; and
@@ -176,15 +169,13 @@
__syncthreads();
if (threadIdx.x < 4)
  boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
__syncthreads();
int s_begin = boundary_buf[0], t_begin = boundary_buf[1],
    s_end = boundary_buf[2], t_end = boundary_buf[3];
s_block_begin += s_begin;
t_block_begin += t_begin;
@@ -200,95 +191,61 @@
if (block_S <= 0 || block_T <= 0)
  continue;
// Load px_buf and py_buf.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
  int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
      s = s_in_block + s_block_begin, t = t_in_block + t_block_begin,
      t_off = t + t_offset;
  // comparing as unsigned int makes sure the index is nonnegative.
  // Caution: if s_begin > 0 or t_begin > 0 we may end up loading some px
  // and py values that are outside the proper boundaries that we need, but
  // the corresponding p_buf values will end up being 0 so this won't
  // matter.
  scalar_t this_px = -INFINITY;
  // Below, "&& t <= t_end" can be interpreted as:
  // "&& (modified ? t_off < t_end : t_off <= t_end)"
  // [since px's last valid index is t_end - 1 if modified, else t_end].
  if (s > s_begin && s <= s_end && t_off >= t_begin && t <= t_end)
    this_px = px[b][s - 1][t_off];
  px_buf[s_in_block][t_in_block] = this_px;
  scalar_t this_py = -INFINITY;
  if (t > t_begin && t <= t_end && s <= s_end)
    this_py = py[b][s][t - 1];
  py_buf[s_in_block][t_in_block] = this_py;
}
// Load the 1st row and 1st column of p_buf.
// This is the context from previously computed blocks of the
// image. Remember: p_buf[s][t] will correspond to p[s + s_block_begin -
// 1][t + t_block_begin - 1]
if (threadIdx.x <= BLOCK_SIZE) {
  // s_in_p_buf and t_in_p_buf are simply the indexes into p_buf
  int s_in_p_buf = threadIdx.x, t_in_p_buf = 0,
      s = s_in_p_buf + s_block_begin - 1,
      t = t_in_p_buf + t_block_begin - 1;
  scalar_t this_p = -INFINITY;
  if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
    this_p = p[b][s][t];
  p_buf[s_in_p_buf][t_in_p_buf] = this_p;
} else if (static_cast<unsigned int>(static_cast<int>(threadIdx.x) - 64) <=
           static_cast<unsigned int>(BLOCK_SIZE)) {
  // Another warp handles the other leg. Checking as unsigned
  // tests that threadIdx.x - 64 is both >= 0 and <= BLOCK_SIZE
  int s_in_p_buf = 0, t_in_p_buf = static_cast<int>(threadIdx.x) - 64,
      s = s_in_p_buf + s_block_begin - 1,
      t = t_in_p_buf + t_block_begin - 1;
  scalar_t this_p = -INFINITY;
  if (s >= s_begin && s <= s_end && t >= t_begin && t <= t_end)
    this_p = p[b][s][t];
  p_buf[s_in_p_buf][t_in_p_buf] = this_p;
}
__syncthreads();
// from here to the next __syncthreads(), only the 1st warp should be active
// so we shouldn't need to synchronize. (implicit within-warp
// synchronization).
@@ -299,19 +256,12 @@
// to set p_buf to 0.0 == log(1.0) if this is the "origin block",
// i.e. s == s_begin, t == t_begin. This corresponds to the
// probability of the pair of sequences of length (0, 0).
p_buf[1][1] =
    (is_origin_block ? 0.0
                     : LogAdd(
                           // px_buf has t_offset applied.
                           p_buf[0][1 + t_offset] + px_buf[0][0],
                           p_buf[1][0] + py_buf[0][0]));
}
int s = threadIdx.x;
@@ -333,34 +283,23 @@
    static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
  // p_buf is indexed by s + 1 and t + 1 because it has an extra initial
  // row and column for context from previous blocks. Taking into account
  // the way these buffers relate to the tensors p, px and py, the code
  // below can be interpreted as follows,
  // writing sbb for s_block_begin and tbb for t_block_begin:
  //
  //   p[b][s+sbb][t+tbb] = LogAdd(p[b][s+sbb-1][t+tbb] +
  //                                  px[s+sbb-1][t+tbb],
  //                               p[b][s+sbb][t+tbb-1] +
  //                                  py[s+sbb][t+tbb-1])
  //
  // where you can see that apart from the offsets of tbb and sbb, this is
  // the same as the recursion defined for p in
  // mutual_information.py:mutual_information_recursion(); and (eq. 0)
  // above.
  // note: px_buf has t_offset applied..
  p_buf[s + 1][t + 1] = LogAdd(p_buf[s][t + 1 + t_offset] + px_buf[s][t],
                               p_buf[s + 1][t] + py_buf[s][t]);
  // We don't need to do __syncthreads() in this loop because all the
  // threads that are active are in the same warp. (However, in future,
  // if NVidia changes some things, we might need to sync here).
@@ -368,21 +307,13 @@
}
__syncthreads();
// Write out the data to p.
for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
  int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
      s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
  if (s_in_block < block_S && t_in_block < block_T) {
    scalar_t this_p = p_buf[s_in_block + 1][t_in_block + 1];
    p[b][s][t] = this_p;
  }
}
@@ -397,165 +328,112 @@
  // you could read block_S below as block_S - 1 + 1, meaning,
  // it's the last index in a block of size block_S, but the indexes into
  // p_buf have a "+ 1". Likewise for block_T.
  ans[b] = p_buf[block_S][block_T];
}
}
}
// like exp(), but returns 0 if arg is inf/nan, or if result would be
// infinity or nan (note: this can happen for out-of-range elements
// when setting px_buf and py_buf if block_S != BLOCK_SIZE or
// block_T != BLOCK_SIZE, and it's a problem because even though
// out-of-range gradients are zero, if we multiply them by infinity
// we get NaN).
template <typename Real> __forceinline__ __device__ Real safe_exp(Real x) {
if (x - x != 0)
return 0;
else {
Real ans = exp(x);
if (ans - ans != 0.0)
return 0;
return ans;
}
}
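// A quick illustrative check of safe_exp() (added comment, not in the original
// source): safe_exp(Real(0)) == 1; safe_exp(-INFINITY) == 0, because
// (-inf) - (-inf) is NaN; safe_exp(NAN) == 0; and for float, safe_exp(1000.0f)
// == 0, because exp() overflows to +inf and the second check maps that to 0.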
/*
  Backward of mutual_information.
  The forward pass is:
     p[b,s,t] = log_add(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset],
                        p[b,s,t-1] + py[b,s,t-1])                         (eq. 0)
  where t_offset = (modified ? -1 : 0).
  The backprop for the above, implemented in the obvious way, would be as
  follows (note, we define term1 and term2 with offsets in the indexes, which
  will be convenient later..):
     term1(b,s-1,t+t_offset) =
         exp(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset] - p[b,s,t])       (0a)
     term2(b,s,t-1) = exp(p[b,s,t-1] + py[b,s,t-1] - p[b,s,t])            (0b)
     p_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset)  (1a)
     px_grad[b,s-1,t+t_offset] += p_grad[b,s,t] * term1(b,s-1,t+t_offset) (1b)
     p_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1)                    (1c)
     py_grad[b,s,t-1] += p_grad[b,s,t] * term2(b,s,t-1)                   (1d)
  Adding 1 and -t_offset to the s and t indexes of (1a) and (1b), and
  1 to the t index of (1c) and (1d), the equations become:
     p_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t)   (2a)
     px_grad[b,s,t] += p_grad[b,s+1,t-t_offset] * term1(b,s,t)  (2b)
     p_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t)            (2c)
     py_grad[b,s,t] += p_grad[b,s,t+1] * term2(b,s,t)           (2d)
  .. and replacing "+=" with "=", we can write:
     p_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t) +  (3a)
                     p_grad[b,s,t+1] * term2(b,s,t)
     px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t)   (3b)
     py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t)            (3c)
  Writing the definitions of term1 and term2 in a more convenient way:
     term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset])   (4a)
     term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1])            (4b)
  The backward pass will be slightly different from the forward pass in terms of
  how we store and index p (and p_grad), because for writing a particular block
  of p_grad, we need context on the top and right instead of the bottom and
  left. So there are offsets of 1.
*/
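// (Added note, not in the original source: from (eq. 0) we have
//    exp(p[b,s,t]) = exp(p[b,s-1,t+t_offset] + px[b,s-1,t+t_offset])
//                    + exp(p[b,s,t-1] + py[b,s,t-1]),
// so by (0a) and (0b), term1(b,s-1,t+t_offset) + term2(b,s,t-1) == 1 wherever
// p[b,s,t] is finite; the backward recursion simply distributes p_grad[b,s,t]
// between the two incoming terms in that proportion.)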
template <typename scalar_t, int BLOCK_SIZE>
__global__ void mutual_information_backward_kernel(
    torch::PackedTensorAccessor32<scalar_t, 3>
        px, // B, S, T + 1 if !modified; B, S, T if modified.
    torch::PackedTensorAccessor32<scalar_t, 3> py, // B, S + 1, T.
    // B, S + 1, T + 1. Produced in forward pass.
    torch::PackedTensorAccessor32<scalar_t, 3> p,
    // [B]. This is an input.
    torch::PackedTensorAccessor32<scalar_t, 1> ans_grad,
    torch::PackedTensorAccessor32<scalar_t, 3>
        p_grad, // B, S + 1, T + 1 if !modified; B, S, T if modified.
    torch::PackedTensorAccessor32<scalar_t, 3> px_grad, // B, S, T + 1.
    torch::PackedTensorAccessor32<scalar_t, 3> py_grad, // B, S + 1, T.
    // B, 4; or 0, 0 if boundaries are the defaults (0, 0, S, T)
    torch::PackedTensorAccessor32<int64_t, 2> boundary,
    int iter, // This kernel is sequentially called with 'iter' = num_iters
              // - 1, num_iters - 2, .. 0, where num_iters can be taken to
              // be any sufficiently large number but will actually be:
              // num_s_blocks + num_t_blocks - 1 where num_s_blocks = S /
              // BLOCK_SIZE + 1 and num_t_blocks = T / BLOCK_SIZE + 1
    bool overwrite_ans_grad) { // If overwrite_ans_grad == true, this function
                               // will overwrite ans_grad with a value which,
                               // if everything is working correctly, should be
                               // identical or very close to the value of
                               // ans_grad that was passed in.
const int B = px.size(0), S = px.size(1), T = py.size(2);
const bool modified = (px.size(2) == T);
const int neg_t_offset = (modified ? 1 : 0);
// For statements that are the same as the forward pass, we are omitting some
// comments. We'll focus, in the comments, on differences from the forward
// pass.
const int num_s_blocks = S / BLOCK_SIZE + 1,
          // num_t_blocks = T / BLOCK_SIZE + 1,
          num_blocks_this_iter = min(iter + 1, num_s_blocks);
// px_buf and py_buf are used temporarily to store the px and py values,
// but then modified to store the "term1" and "term2" values defined
// in (eq. 4a) and (eq. 4b) above. For out-of-range values, we'll write 0.0
@@ -564,15 +442,17 @@
// px_buf[s][t] contains px[s+s_block_begin][t+t_block_begin];
// py_buf[s][t] contains py[s+s_block_begin][t+t_block_begin].
// Later (see eq. 4a and eq. 4b):
// px_buf[s][t] contains term1(b,ss,tt) ==
//    exp(p[b][ss][tt] + px[b][ss][tt] - p[b][ss + 1][tt-t_offset]),
// py_buf[s][t] contains term2(b,ss,tt) ==
//    exp(p[b][ss][tt] + py[b][ss][tt] - p[b][ss][tt + 1]),
// where ss == s + s_block_begin, tt = t + t_block_begin.
// Unlike in the forward code, there is no offset of 1 in the indexes.
__shared__ scalar_t px_buf[BLOCK_SIZE][BLOCK_SIZE],
    py_buf[BLOCK_SIZE][BLOCK_SIZE];
// p_buf is initially used to store p, and then (after we are done putting
// term1 and term2 into px_buf and py_buf) it is repurposed to store
// p_grad.
//
// Unlike in the forward pass, p_buf has the same numbering as px_buf and
@@ -603,19 +483,16 @@
for (int batch_block_iter = blockIdx.x;
     batch_block_iter < B * num_blocks_this_iter;
     batch_block_iter += gridDim.x) {
  int block = batch_block_iter / B, b = batch_block_iter % B;
  int s_block_begin = block * BLOCK_SIZE,
      t_block_begin = (iter - block) * BLOCK_SIZE;
  if (threadIdx.x < 4)
    boundary_buf[threadIdx.x] = boundary[b][threadIdx.x];
  __syncthreads();
  int s_begin = boundary_buf[0], t_begin = boundary_buf[1],
      s_end = boundary_buf[2], t_end = boundary_buf[3];
  s_block_begin += s_begin;
  t_block_begin += t_begin;
@@ -633,13 +510,11 @@
  // Load px_buf and py_buf. At this point we just set them to the px and py
  // for this block.
  for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
    int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
        s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
    // We let px and py default to -infinity if they are out of range, which
    // will cause term1 and term2 for out-of-range values to be zero, and
    // cause correct behavior in edge cases (for the top and right blocks).
    // The issue is that p and p_grad are of larger size than px and py.
    scalar_t this_px = -INFINITY;
    if (s < s_end && t <= t_end)
...@@ -653,11 +528,10 @@ void mutual_information_backward_kernel( ...@@ -653,11 +528,10 @@ void mutual_information_backward_kernel(
__syncthreads(); __syncthreads();
// load p. // load p.
for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1); i += blockDim.x) { for (int i = threadIdx.x; i < (BLOCK_SIZE + 1) * (BLOCK_SIZE + 1);
int s_in_block = i / (BLOCK_SIZE + 1), i += blockDim.x) {
t_in_block = i % (BLOCK_SIZE + 1), int s_in_block = i / (BLOCK_SIZE + 1), t_in_block = i % (BLOCK_SIZE + 1),
s = s_in_block + s_block_begin, s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
t = t_in_block + t_block_begin;
// Setting 0.0 for out-of-bounds elements of p, together with setting // Setting 0.0 for out-of-bounds elements of p, together with setting
// -INFINITY for out-of-bounds elements of px_buf and py_buf, will // -INFINITY for out-of-bounds elements of px_buf and py_buf, will
// ensure that we do the right thing in top and right edge cases, // ensure that we do the right thing in top and right edge cases,
...@@ -666,56 +540,57 @@ void mutual_information_backward_kernel( ...@@ -666,56 +540,57 @@ void mutual_information_backward_kernel(
scalar_t this_p = 0.0; scalar_t this_p = 0.0;
if (s <= s_end && t <= t_end) if (s <= s_end && t <= t_end)
this_p = p[b][s][t]; this_p = p[b][s][t];
// if this_p is -inf, replace with large finite negative value, to avoid
// NaN's below.
// TODO: use a value that would work correctly in half precision
if (this_p < -1.0e+30)
this_p = -1.0e+30;
p_buf[s_in_block][t_in_block] = this_p; p_buf[s_in_block][t_in_block] = this_p;
} }
__syncthreads(); __syncthreads();
    // Set term1 and term2; see equations (4a) and (4b) above.
    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
      // We can apply this formula to the entire block even if we are
      // processing a partial block; we have ensured that x_buf and y_buf
      // contain -infinity, and p contains 0, for out-of-range elements, so
      // we'll get x_buf and y_buf containing 0 after applying the following
      // formulas.
      int s = i / BLOCK_SIZE, t = i % BLOCK_SIZE;
      // Mathematically the following is doing:
      //   term1(b,s,t) = exp(p[b,s,t] + px[b,s,t] - p[b,s+1,t-t_offset])  (4a)
      // (with an offset on the s and t indexes)
      // Use safe_exp() not exp(), as we could have (-inf) - (-inf) = nan; we
      // want any finite number in this case, as the derivs would be zero.
      // Also want -inf -> zero.  [An illustrative sketch of such a helper
      // appears after this kernel.]
      px_buf[s][t] =
          safe_exp(p_buf[s][t] + px_buf[s][t] - p_buf[s + 1][t + neg_t_offset]);
      // Mathematically the following is doing:
      //   term2(b,s,t) = exp(p[b,s,t] + py[b,s,t] - p[b,s,t+1])           (4b)
      // (with an offset on the s and t indexes)
      py_buf[s][t] = safe_exp(p_buf[s][t] + py_buf[s][t] - p_buf[s][t + 1]);
    }
    __syncthreads();
    // Load p_grad for the top and right elements in p_buf: i.e. for elements
    // p_buf[s][t] where s == block_S (exclusive-or) t == block_T.
    // These are the p_grad values computed by previous instances of this
    // kernel.  If this is one of the top or right blocks, some or all of the
    // p_grad values we'd be reading here will be out of range, and we use
    // zeros to ensure no gradient gets propagated from those positions.
    if (threadIdx.x <= block_S) {
      int s_in_block = threadIdx.x, t_in_block = block_T,
          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
      p_buf[s_in_block][t_in_block] =
          (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
    } else if (static_cast<unsigned int>(static_cast<int>(threadIdx.x) - 64) <
               static_cast<unsigned int>(block_T)) {
      // casting to unsigned before the comparison tests for both negative and
      // out-of-range values of (int)threadIdx.x - 64.
      int s_in_block = block_S, t_in_block = static_cast<int>(threadIdx.x) - 64,
          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
      p_buf[s_in_block][t_in_block] =
          (s <= s_end && t <= t_end ? p_grad[b][s][t] : 0.0);
    }
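    // [Editor's note, not in the original file] For example, with
    // block_S == block_T == 32, threads 0..32 load the column
    // t_in_block == block_T and threads 64..95 load the row
    // s_in_block == block_S; this relies on blockDim.x (num_threads, which is
    // 128 in the wrappers below) being at least 64 + BLOCK_SIZE.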
    __syncthreads();
...@@ -748,10 +623,11 @@ void mutual_information_backward_kernel(
          static_cast<unsigned int>(t) < static_cast<unsigned int>(block_T)) {
        // The following statement is really operating on the gradients;
        // it corresponds, with offsets of s_block_begin and t_block_begin
        // on the indexes, to equation (3a) above, i.e.:
        //   p_grad[b,s,t] =
        //       p_grad[b,s+1,t-t_offset] * term1(b,s,t) +                 (3a)
        //       p_grad[b,s,t+1] * term2(b,s,t)
        p_buf[s][t] = (p_buf[s + 1][t + neg_t_offset] * px_buf[s][t] +
                       p_buf[s][t + 1] * py_buf[s][t]);
      }
    }
...@@ -761,24 +637,27 @@ void mutual_information_backward_kernel(
    // Write out p_grad, px_grad and py_grad.
    for (int i = threadIdx.x; i < BLOCK_SIZE * BLOCK_SIZE; i += blockDim.x) {
      int s_in_block = i / BLOCK_SIZE, t_in_block = i % BLOCK_SIZE,
          s = s_in_block + s_block_begin, t = t_in_block + t_block_begin;
      // s_end and t_end are the one-past-the-end of the (x,y) sequences, but
      // the one-past-the-end element of p_grad would be (s_end + 1, t_end + 1).
      if (t <= t_end && s <= s_end) {
        p_grad[b][s][t] = p_buf[s_in_block][t_in_block];
        if (s < s_end && t <= t_end - neg_t_offset) {
          // Write px_grad, which is of shape [B][S][T + 1] if !modified,
          // [B][S][T] if modified.  The condition "t <= t_end - neg_t_offset"
          // becomes "t <= t_end" if !modified, and "t <= t_end - 1" if
          // modified, keeping us within the bounds of px_grad.
          // From (eq. 3b):
          //   px_grad[b,s,t] = p_grad[b,s+1,t-t_offset] * term1(b,s,t)
          px_grad[b][s][t] =
              (p_buf[s_in_block + 1][t_in_block + neg_t_offset] *
               px_buf[s_in_block][t_in_block]);
        }
        if (t < t_end) {  // write py_grad, which is of shape [B][S + 1][T]
          // From (eq. 3c):
          //   py_grad[b,s,t] = p_grad[b,s,t+1] * term2(b,s,t)
          py_grad[b][s][t] = (p_buf[s_in_block][t_in_block + 1] *
                              py_buf[s_in_block][t_in_block]);
        }
...@@ -791,81 +670,77 @@ void mutual_information_backward_kernel(
  }
}
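// [Editor's note, not in the original file] safe_exp() used above is defined
// elsewhere in this project and is not shown in this hunk.  A minimal sketch
// with the semantics the comments above describe (an exp() that maps nan --
// e.g. from (-inf) - (-inf) -- and very negative arguments to zero) might look
// like the following; the name is hypothetical and the real definition may
// differ.
template <typename scalar_t>
__forceinline__ __device__ scalar_t illustrative_safe_exp(scalar_t x) {
  // nan compares unequal to itself; very negative x would underflow to zero
  // anyway, and a zero result here means "no gradient flows through".
  if (x != x || x < scalar_t(-80)) return scalar_t(0);
  return exp(x);
}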
// forward of mutual_information.  See """... """ comment of
// `mutual_information` in mutual_information.py for documentation of the
// behavior of this function.
torch::Tensor MutualInformationCuda(torch::Tensor px, torch::Tensor py,
                                    torch::optional<torch::Tensor> opt_boundary,
                                    torch::Tensor p) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
                  p.device().is_cuda(),
              "inputs must be CUDA tensors");

  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0), S = px.size(1), T = py.size(2);
  TORCH_CHECK(px.size(2) == T || px.size(2) == T + 1);
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  auto boundary = opt_boundary.value_or(
      torch::tensor({0, 0, S, T},
                    torch::dtype(torch::kInt64).device(px.device()))
          .reshape({1, 4})
          .expand({B, 4}));
  TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
  TORCH_CHECK(boundary.device().is_cuda() && boundary.dtype() == torch::kInt64);

  torch::Tensor ans = torch::empty({B}, opts);
  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
  // (however, num_threads may not be less than 128).
  const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32;

  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
  // so dividing by BLOCK_SIZE rounding up we get e.g.
  // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
  const int num_s_blocks = S / BLOCK_SIZE + 1,
            num_t_blocks = T / BLOCK_SIZE + 1,
            num_iters = num_s_blocks + num_t_blocks - 1;
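  // [Editor's note, not in the original file] Worked example: with S = 100,
  // T = 200 and BLOCK_SIZE = 32, num_s_blocks = 100 / 32 + 1 = 4 and
  // num_t_blocks = 200 / 32 + 1 = 7, so the kernel below is launched
  // num_iters = 4 + 7 - 1 = 10 times, once per anti-diagonal of blocks.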
  AT_DISPATCH_FLOATING_TYPES(
      px.scalar_type(), "mutual_information_cuda_stub", ([&] {
        for (int iter = 0; iter < num_iters; ++iter) {
          mutual_information_kernel<scalar_t, BLOCK_SIZE>
              <<<num_blocks, num_threads>>>(
                  px.packed_accessor32<scalar_t, 3>(),
                  py.packed_accessor32<scalar_t, 3>(),
                  p.packed_accessor32<scalar_t, 3>(),
                  boundary.packed_accessor32<int64_t, 2>(),
                  ans.packed_accessor32<scalar_t, 1>(), iter);
        }
      }));

  return ans;
}
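// [Editor's note, not in the original file] The following is an illustrative
// sketch showing how the wrapper above could be called from C++, assuming only
// the shapes its TORCH_CHECKs require.  The helper name is hypothetical.
static torch::Tensor MutualInformationForwardExample(torch::Tensor px,
                                                     torch::Tensor py) {
  // px: (B, S, T+1), or (B, S, T) in the "modified" case; py: (B, S+1, T).
  const int B = px.size(0), S = px.size(1), T = py.size(2);
  auto opts =
      torch::TensorOptions().dtype(px.scalar_type()).device(px.device());
  // p is filled in by the kernel and must be kept around for the backward.
  torch::Tensor p = torch::empty({B, S + 1, T + 1}, opts);
  // An empty optional selects the default boundary (0, 0, S, T) per sequence.
  return MutualInformationCuda(px, py, torch::optional<torch::Tensor>(), p);
}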
// backward of mutual_information; returns (grad_px, grad_py).
// If overwrite_ans_grad == true, will overwrite ans_grad with a value which
// should be identical to the original ans_grad if the computation worked
// as it should.
std::vector<torch::Tensor>
MutualInformationBackwardCuda(torch::Tensor px, torch::Tensor py,
                              torch::optional<torch::Tensor> opt_boundary,
                              torch::Tensor p, torch::Tensor ans_grad,
                              bool overwrite_ans_grad) {
  TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
  TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
  TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
  TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
  TORCH_CHECK(px.device().is_cuda() && py.device().is_cuda() &&
                  p.device().is_cuda() && ans_grad.device().is_cuda(),
              "inputs must be CUDA tensors");
...@@ -873,55 +748,59 @@ mutual_information_backward_cuda(torch::Tensor px,
  auto scalar_t = px.scalar_type();
  auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());

  const int B = px.size(0), S = px.size(1), T = py.size(2);
  TORCH_CHECK(px.size(2) == T ||
              px.size(2) == T + 1);  // modified case || not-modified case
  const bool modified = (px.size(2) == T);
  TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
  TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);

  auto boundary = opt_boundary.value_or(
      torch::tensor({0, 0, S, T},
                    torch::dtype(torch::kInt64).device(px.device()))
          .reshape({1, 4})
          .expand({B, 4}));
  TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
  TORCH_CHECK(boundary.device().is_cuda() && boundary.dtype() == torch::kInt64);
  TORCH_CHECK(ans_grad.size(0) == B);
  bool has_boundary = opt_boundary.has_value();
  int T1 = T + (modified ? 0 : 1);
  torch::Tensor p_grad = torch::empty({B, S + 1, T + 1}, opts),
                px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
                                        : torch::empty({B, S, T1}, opts)),
                py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
                                        : torch::empty({B, S + 1, T}, opts));
  // num_threads and num_blocks and BLOCK_SIZE can be tuned.
  // (however, num_threads may not be less than 128).
  const int num_threads = 128, num_blocks = 256, BLOCK_SIZE = 32;

  // The blocks cover the 'p' matrix, which is of size (B, S+1, T+1),
  // so dividing by BLOCK_SIZE rounding up we get e.g.
  // (S+1 + BLOCK_SIZE-1) / BLOCK_SIZE == S / BLOCK_SIZE + 1
  const int num_s_blocks = S / BLOCK_SIZE + 1,
            num_t_blocks = T / BLOCK_SIZE + 1,
            num_iters = num_s_blocks + num_t_blocks - 1;
  AT_DISPATCH_FLOATING_TYPES(
      px.scalar_type(), "mutual_information_backward_stub", ([&] {
        for (int iter = num_iters - 1; iter >= 0; --iter) {
          mutual_information_backward_kernel<scalar_t, BLOCK_SIZE>
              <<<num_blocks, num_threads>>>(
                  px.packed_accessor32<scalar_t, 3>(),
                  py.packed_accessor32<scalar_t, 3>(),
                  p.packed_accessor32<scalar_t, 3>(),
                  ans_grad.packed_accessor32<scalar_t, 1>(),
                  p_grad.packed_accessor32<scalar_t, 3>(),
                  px_grad.packed_accessor32<scalar_t, 3>(),
                  py_grad.packed_accessor32<scalar_t, 3>(),
                  boundary.packed_accessor32<int64_t, 2>(), iter,
                  overwrite_ans_grad);
        }
      }));
// std::cout << "p_grad = " << p_grad;
  return std::vector<torch::Tensor>({px_grad, py_grad});
}
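// [Editor's note, not in the original file] Another illustrative sketch: one
// way to exercise the backward wrapper above, including the overwrite_ans_grad
// self-check described in its comment.  The helper name and tolerances are
// hypothetical.
static std::vector<torch::Tensor> MutualInformationBackwardExample(
    torch::Tensor px, torch::Tensor py, torch::Tensor p,
    torch::Tensor ans_grad) {
  // Clone ans_grad so the kernel's reconstruction can be compared against it.
  torch::Tensor ans_grad_copy = ans_grad.clone();
  std::vector<torch::Tensor> grads = MutualInformationBackwardCuda(
      px, py, torch::optional<torch::Tensor>(), p, ans_grad_copy,
      /*overwrite_ans_grad=*/true);
  // If the backward computation is self-consistent, ans_grad_copy should now
  // be numerically close to the original ans_grad.
  TORCH_CHECK(torch::allclose(ans_grad_copy, ans_grad, /*rtol=*/1e-2,
                              /*atol=*/1e-4),
              "overwrite_ans_grad self-check failed");
  return grads;  // {px_grad, py_grad}
}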
} // namespace fast_rnnt
add_subdirectory(csrc)
add_subdirectory(tests)
include_directories(${CMAKE_SOURCE_DIR})
pybind11_add_module(_fast_rnnt
mutual_information.cu
)
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
if(UNIX AND NOT APPLE)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()