Commit d53e923b authored by pkufool

Move k2 rnnt_loss here

parent b5828e2b
@@ -2,3 +2,4 @@
.idea
venv*
deploy*
__pycache__/*
if("x${CMAKE_SOURCE_DIR}" STREQUAL "x${CMAKE_BINARY_DIR}")
message(FATAL_ERROR "\
In-source build is not a good practice.
Please use:
mkdir build
cd build
cmake ..
to build this project"
)
endif()
cmake_minimum_required(VERSION 3.8 FATAL_ERROR)
set(languages CXX)
set(_FT_WITH_CUDA ON)
set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_program(FT_HAS_NVCC nvcc)
if(NOT FT_HAS_NVCC AND "$ENV{CUDACXX}" STREQUAL "")
message(STATUS "No NVCC detected. Disable CUDA support")
set(_FT_WITH_CUDA OFF)
endif()
if(APPLE OR (DEFINED FT_WITH_CUDA AND NOT FT_WITH_CUDA))
if(_FT_WITH_CUDA)
message(STATUS "Disable CUDA support")
set(_FT_WITH_CUDA OFF)
endif()
endif()
if(_FT_WITH_CUDA)
set(languages ${languages} CUDA)
if(NOT DEFINED FT_WITH_CUDA)
set(FT_WITH_CUDA ON)
endif()
endif()
message(STATUS "Enabled languages: ${languages}")
project(fast_rnnt ${languages})
set(FT_VERSION "1.0")
set(ALLOWABLE_BUILD_TYPES Debug Release RelWithDebInfo MinSizeRel)
set(DEFAULT_BUILD_TYPE "Release")
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "${ALLOWABLE_BUILD_TYPES}")
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
# CMAKE_CONFIGURATION_TYPES is set (instead of CMAKE_BUILD_TYPE) by multi-config generators (IDEs).
message(STATUS "No CMAKE_BUILD_TYPE given, default to ${DEFAULT_BUILD_TYPE}")
set(CMAKE_BUILD_TYPE "${DEFAULT_BUILD_TYPE}")
elseif(NOT CMAKE_BUILD_TYPE IN_LIST ALLOWABLE_BUILD_TYPES)
message(FATAL_ERROR "Invalid build type: ${CMAKE_BUILD_TYPE}, \
choose one from ${ALLOWABLE_BUILD_TYPES}")
endif()
option(FT_BUILD_TESTS "Whether to build tests or not" OFF)
option(BUILD_SHARED_LIBS "Whether to build shared libs" ON)
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(BUILD_RPATH_USE_ORIGIN TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
set(CMAKE_INSTALL_RPATH "$ORIGIN")
set(CMAKE_BUILD_RPATH "$ORIGIN")
if(FT_WITH_CUDA)
add_definitions(-DFT_WITH_CUDA)
# Force CUDA C++ standard to be the same as the C++ standard used.
#
# Now, CMake is unaligned with reality on standard versions: https://gitlab.kitware.com/cmake/cmake/issues/18597
# which means that using standard CMake methods, it's impossible to actually sync the CXX and CUDA versions for pre-11
# versions of C++; CUDA accepts 98 but translates that to 03, while CXX doesn't accept 03 (and doesn't translate 98 to 03).
# In case this gives you, dear user, any trouble, please escalate the above CMake bug, so we can support reality properly.
if(DEFINED CMAKE_CUDA_STANDARD)
message(WARNING "You've set CMAKE_CUDA_STANDARD; please note that this variable is ignored, and CMAKE_CXX_STANDARD"
" is used as the C++ standard version for both C++ and CUDA.")
endif()
unset(CMAKE_CUDA_STANDARD CACHE)
set(CMAKE_CUDA_STANDARD ${CMAKE_CXX_STANDARD})
include(cmake/select_compute_arch.cmake)
cuda_select_nvcc_arch_flags(FT_COMPUTE_ARCH_FLAGS)
message(STATUS "FT_COMPUTE_ARCH_FLAGS: ${FT_COMPUTE_ARCH_FLAGS}")
# set(OT_COMPUTE_ARCHS 30 32 35 50 52 53 60 61 62 70 72)
# message(WARNING "arch 62/72 are not supported for now")
# see https://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/
# https://www.myzhar.com/blog/tutorials/tutorial-nvidia-gpu-cuda-compute-capability/
set(FT_COMPUTE_ARCH_CANDIDATES 35 50 60 61 70 75)
if(CUDA_VERSION VERSION_GREATER "11.0")
list(APPEND FT_COMPUTE_ARCH_CANDIDATES 80 86)
endif()
message(STATUS "FT_COMPUTE_ARCH_CANDIDATES ${FT_COMPUTE_ARCH_CANDIDATES}")
set(FT_COMPUTE_ARCHS)
foreach(COMPUTE_ARCH IN LISTS FT_COMPUTE_ARCH_CANDIDATES)
if("${FT_COMPUTE_ARCH_FLAGS}" MATCHES ${COMPUTE_ARCH})
message(STATUS "Adding arch ${COMPUTE_ARCH}")
list(APPEND FT_COMPUTE_ARCHS ${COMPUTE_ARCH})
else()
message(STATUS "Skipping arch ${COMPUTE_ARCH}")
endif()
endforeach()
if(NOT FT_COMPUTE_ARCHS)
set(FT_COMPUTE_ARCHS ${FT_COMPUTE_ARCH_CANDIDATES})
endif()
message(STATUS "FT_COMPUTE_ARCHS: ${FT_COMPUTE_ARCHS}")
foreach(COMPUTE_ARCH IN LISTS FT_COMPUTE_ARCHS)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-extended-lambda -gencode arch=compute_${COMPUTE_ARCH},code=sm_${COMPUTE_ARCH}")
set(CMAKE_CUDA_ARCHITECTURES "${COMPUTE_ARCH}-real;${COMPUTE_ARCH}-virtual;${CMAKE_CUDA_ARCHITECTURES}")
endforeach()
endif()
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake)
include(pybind11)
include(torch)
if(FT_BUILD_TESTS)
enable_testing()
include(googletest)
endif()
add_subdirectory(fast_rnnt)
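Taken together, the pieces above define the usual out-of-source workflow. As a hedged example, a build could be configured with the options this file declares (FT_BUILD_TESTS, BUILD_SHARED_LIBS, CMAKE_BUILD_TYPE, and FT_WITH_CUDA):
mkdir build
cd build
cmake -DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=ON -DFT_WITH_CUDA=ON ..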
# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
# file Copyright.txt or https://cmake.org/licensing for details.
cmake_minimum_required(VERSION ${CMAKE_VERSION})
# We name the project and the target for the ExternalProject_Add() call
# to something that will highlight to the user what we are working on if
# something goes wrong and an error message is produced.
project(${contentName}-populate NONE)
include(ExternalProject)
ExternalProject_Add(${contentName}-populate
${ARG_EXTRA}
SOURCE_DIR "${ARG_SOURCE_DIR}"
BINARY_DIR "${ARG_BINARY_DIR}"
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
# Try to find Valgrind headers and libraries.
#
# Usage of this module as follows:
# find_package(Valgrind)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# VALGRIND_ROOT Set this variable to the root installation of valgrind if the
# module has problems finding the proper installation path.
#
# Variables defined by this module:
# Valgrind_FOUND System has valgrind
# Valgrind_INCLUDE_DIR where to find valgrind/memcheck.h, etc.
# Valgrind_EXECUTABLE the valgrind executable.
# Get hint from environment variable (if any)
if(NOT VALGRIND_ROOT AND DEFINED ENV{VALGRIND_ROOT})
set(VALGRIND_ROOT "$ENV{VALGRIND_ROOT}" CACHE PATH "Valgrind base directory location (optional, used for nonstandard installation paths)")
mark_as_advanced(VALGRIND_ROOT)
endif()
# Search path for nonstandard locations
if(VALGRIND_ROOT)
set(Valgrind_INCLUDE_PATH PATHS "${VALGRIND_ROOT}/include" NO_DEFAULT_PATH)
set(Valgrind_BINARY_PATH PATHS "${VALGRIND_ROOT}/bin" NO_DEFAULT_PATH)
endif()
find_path(Valgrind_INCLUDE_DIR valgrind HINTS ${Valgrind_INCLUDE_PATH})
find_program(Valgrind_EXECUTABLE NAMES valgrind HINTS ${Valgrind_BINARY_PATH})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Valgrind DEFAULT_MSG Valgrind_INCLUDE_DIR Valgrind_EXECUTABLE)
mark_as_advanced(Valgrind_INCLUDE_DIR Valgrind_EXECUTABLE)
if(NOT Valgrind_FOUND)
if(Valgrind_FIND_REQUIRED)
message(FATAL_ERROR "Valgrind required but it seems it has not be installed.")
endif()
else()
message(STATUS "Found Valgrind: ${Valgrind_EXECUTABLE}")
endif()
## FetchContent
`FetchContent.cmake` and `FetchContent/CMakeLists.cmake.in`
are copied from `cmake/3.11.0/share/cmake-3.11/Modules`.
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_googletest)
if(CMAKE_VERSION VERSION_LESS 3.11)
# FetchContent is available since 3.11,
# we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules
# so that it can be used in lower CMake versions.
message(STATUS "Use FetchContent provided by k2")
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(googletest_URL "https://github.com/google/googletest/archive/release-1.10.0.tar.gz")
set(googletest_HASH "SHA256=9dc9157a9a1551ec7a7e43daea9a694a0bb5fb8bec81235d8a1e6ef64c716dcb")
set(BUILD_GMOCK ON CACHE BOOL "" FORCE)
set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
set(gtest_disable_pthreads ON CACHE BOOL "" FORCE)
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_Declare(googletest
URL ${googletest_URL}
URL_HASH ${googletest_HASH}
)
FetchContent_GetProperties(googletest)
if(NOT googletest_POPULATED)
message(STATUS "Downloading googletest")
FetchContent_Populate(googletest)
endif()
message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}")
message(STATUS "googletest's binary dir is ${googletest_BINARY_DIR}")
if(APPLE)
set(CMAKE_MACOSX_RPATH ON) # to solve the following warning on macOS
endif()
#[==[
-- Generating done
Policy CMP0042 is not set: MACOSX_RPATH is enabled by default. Run "cmake
--help-policy CMP0042" for policy details. Use the cmake_policy command to
set the policy and suppress this warning.
MACOSX_RPATH is not specified for the following targets:
gmock
gmock_main
gtest
gtest_main
This warning is for project developers. Use -Wno-dev to suppress it.
]==]
add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL)
target_include_directories(gtest
INTERFACE
${googletest_SOURCE_DIR}/googletest/include
${googletest_SOURCE_DIR}/googlemock/include
)
endfunction()
download_googletest()
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# See ../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
function(download_pybind11)
if(CMAKE_VERSION VERSION_LESS 3.11)
list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules)
endif()
include(FetchContent)
set(pybind11_URL "https://github.com/pybind/pybind11/archive/v2.6.0.tar.gz")
set(pybind11_HASH "SHA256=90b705137b69ee3b5fc655eaca66d0dc9862ea1759226f7ccd3098425ae69571")
FetchContent_Declare(pybind11
URL ${pybind11_URL}
URL_HASH ${pybind11_HASH}
)
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
message(STATUS "Downloading pybind11")
FetchContent_Populate(pybind11)
endif()
message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}")
add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR} EXCLUDE_FROM_ALL)
endfunction()
download_pybind11()
#
# This file is copied from
# https://github.com/pytorch/pytorch/blob/master/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake
#
#
# Synopsis:
# CUDA_SELECT_NVCC_ARCH_FLAGS(out_variable [target_CUDA_architectures])
# -- Selects GPU arch flags for nvcc based on target_CUDA_architectures
# target_CUDA_architectures : Auto | Common | All | LIST(ARCH_AND_PTX ...)
# - "Auto" detects local machine GPU compute arch at runtime.
# - "Common" and "All" cover common and entire subsets of architectures
# ARCH_AND_PTX : NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX
# NAME: Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal Volta Turing Ampere
# NUM: Any number. Only those pairs are currently accepted by NVCC though:
# 3.5 3.7 5.0 5.2 5.3 6.0 6.2 7.0 7.2 7.5 8.0
# Returns LIST of flags to be added to CUDA_NVCC_FLAGS in ${out_variable}
# Additionally, sets ${out_variable}_readable to the resulting numeric list
# Example:
# CUDA_SELECT_NVCC_ARCH_FLAGS(ARCH_FLAGS 3.0 3.5+PTX 5.2(5.0) Maxwell)
# LIST(APPEND CUDA_NVCC_FLAGS ${ARCH_FLAGS})
#
# More info on CUDA architectures: https://en.wikipedia.org/wiki/CUDA
#
if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language
if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA"
AND CMAKE_CUDA_COMPILER_VERSION MATCHES "^([0-9]+\\.[0-9]+)")
set(CUDA_VERSION "${CMAKE_MATCH_1}")
endif()
endif()
# See: https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#gpu-feature-list
# This list will be used for CUDA_ARCH_NAME = All option
set(CUDA_KNOWN_GPU_ARCHITECTURES "Kepler" "Maxwell")
# This list will be used for CUDA_ARCH_NAME = Common option (enabled by default)
set(CUDA_COMMON_GPU_ARCHITECTURES "3.5" "5.0")
if(CUDA_VERSION VERSION_LESS "7.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "5.2")
endif()
# This list is used to filter CUDA archs when autodetecting
set(CUDA_ALL_GPU_ARCHITECTURES "3.5" "5.0")
if(CUDA_VERSION VERSION_GREATER "6.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Kepler+Tegra" "Kepler+Tesla" "Maxwell+Tegra")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2")
if(CUDA_VERSION VERSION_LESS "8.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "5.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "6.0")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "7.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Pascal")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.0" "6.1")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "6.0" "6.1" "6.2")
if(CUDA_VERSION VERSION_LESS "9.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "6.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "7.0")
endif()
endif ()
if(CUDA_VERSION VERSION_GREATER "8.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Volta")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.0" "7.2")
if(CUDA_VERSION VERSION_LESS "10.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.2+PTX")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "9.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Turing")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "7.5")
if(CUDA_VERSION VERSION_LESS "11.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "7.5+PTX")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "10.5")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ampere")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.0")
if(CUDA_VERSION VERSION_LESS "11.1")
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.6")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.0+PTX")
endif()
endif()
if(CUDA_VERSION VERSION_GREATER "11.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6" "8.6+PTX")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6")
if(CUDA_VERSION VERSION_LESS "12.0")
set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
endif()
endif()
################################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
# CUDA_DETECT_INSTALLED_GPUS(OUT_VARIABLE)
#
function(CUDA_DETECT_INSTALLED_GPUS OUT_VARIABLE)
if(NOT CUDA_GPU_DETECT_OUTPUT)
if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language
set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cu")
else()
set(file "${PROJECT_BINARY_DIR}/detect_cuda_compute_capabilities.cpp")
endif()
file(WRITE ${file} ""
"#include <cuda_runtime.h>\n"
"#include <cstdio>\n"
"int main()\n"
"{\n"
" int count = 0;\n"
" if (cudaSuccess != cudaGetDeviceCount(&count)) return -1;\n"
" if (count == 0) return -1;\n"
" for (int device = 0; device < count; ++device)\n"
" {\n"
" cudaDeviceProp prop;\n"
" if (cudaSuccess == cudaGetDeviceProperties(&prop, device))\n"
" std::printf(\"%d.%d \", prop.major, prop.minor);\n"
" }\n"
" return 0;\n"
"}\n")
if(CMAKE_CUDA_COMPILER_LOADED OR DEFINED CMAKE_CUDA_COMPILER_ID) # CUDA as a language
try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
RUN_OUTPUT_VARIABLE compute_capabilities)
else()
try_run(run_result compile_result ${PROJECT_BINARY_DIR} ${file}
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}"
LINK_LIBRARIES ${CUDA_LIBRARIES}
RUN_OUTPUT_VARIABLE compute_capabilities)
endif()
# Filter unrelated content out of the output.
string(REGEX MATCHALL "[0-9]+\\.[0-9]+" compute_capabilities "${compute_capabilities}")
if(run_result EQUAL 0)
string(REPLACE "2.1" "2.1(2.0)" compute_capabilities "${compute_capabilities}")
set(CUDA_GPU_DETECT_OUTPUT ${compute_capabilities}
CACHE INTERNAL "Returned GPU architectures from detect_gpus tool" FORCE)
endif()
endif()
if(NOT CUDA_GPU_DETECT_OUTPUT)
message(STATUS "Automatic GPU detection failed. Building for common architectures.")
set(${OUT_VARIABLE} ${CUDA_COMMON_GPU_ARCHITECTURES} PARENT_SCOPE)
else()
# Filter based on CUDA version supported archs
set(CUDA_GPU_DETECT_OUTPUT_FILTERED "")
separate_arguments(CUDA_GPU_DETECT_OUTPUT)
foreach(ITEM IN ITEMS ${CUDA_GPU_DETECT_OUTPUT})
if(CUDA_LIMIT_GPU_ARCHITECTURE AND (ITEM VERSION_GREATER CUDA_LIMIT_GPU_ARCHITECTURE OR
ITEM VERSION_EQUAL CUDA_LIMIT_GPU_ARCHITECTURE))
list(GET CUDA_COMMON_GPU_ARCHITECTURES -1 NEWITEM)
string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${NEWITEM}")
else()
string(APPEND CUDA_GPU_DETECT_OUTPUT_FILTERED " ${ITEM}")
endif()
endforeach()
set(${OUT_VARIABLE} ${CUDA_GPU_DETECT_OUTPUT_FILTERED} PARENT_SCOPE)
endif()
endfunction()
################################################################################################
# Function for selecting GPU arch flags for nvcc based on CUDA architectures from parameter list
# Usage:
# SELECT_NVCC_ARCH_FLAGS(out_variable [list of CUDA compute archs])
function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
set(CUDA_ARCH_LIST "${ARGN}")
if("X${CUDA_ARCH_LIST}" STREQUAL "X" )
set(CUDA_ARCH_LIST "Auto")
endif()
set(cuda_arch_bin)
set(cuda_arch_ptx)
if("${CUDA_ARCH_LIST}" STREQUAL "All")
set(CUDA_ARCH_LIST ${CUDA_KNOWN_GPU_ARCHITECTURES})
elseif("${CUDA_ARCH_LIST}" STREQUAL "Common")
set(CUDA_ARCH_LIST ${CUDA_COMMON_GPU_ARCHITECTURES})
elseif("${CUDA_ARCH_LIST}" STREQUAL "Auto")
CUDA_DETECT_INSTALLED_GPUS(CUDA_ARCH_LIST)
message(STATUS "Autodetected CUDA architecture(s): ${CUDA_ARCH_LIST}")
endif()
# Now process the list and look for names
string(REGEX REPLACE "[ \t]+" ";" CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
list(REMOVE_DUPLICATES CUDA_ARCH_LIST)
foreach(arch_name ${CUDA_ARCH_LIST})
set(arch_bin)
set(arch_ptx)
set(add_ptx FALSE)
# Check to see if we are compiling PTX
if(arch_name MATCHES "(.*)\\+PTX$")
set(add_ptx TRUE)
set(arch_name ${CMAKE_MATCH_1})
endif()
if(arch_name MATCHES "^([0-9]\\.[0-9](\\([0-9]\\.[0-9]\\))?)$")
set(arch_bin ${CMAKE_MATCH_1})
set(arch_ptx ${arch_bin})
else()
# Look for it in our list of known architectures
if(${arch_name} STREQUAL "Kepler+Tesla")
set(arch_bin 3.7)
elseif(${arch_name} STREQUAL "Kepler")
set(arch_bin 3.5)
set(arch_ptx 3.5)
elseif(${arch_name} STREQUAL "Maxwell+Tegra")
set(arch_bin 5.3)
elseif(${arch_name} STREQUAL "Maxwell")
set(arch_bin 5.0 5.2)
set(arch_ptx 5.2)
elseif(${arch_name} STREQUAL "Pascal")
set(arch_bin 6.0 6.1)
set(arch_ptx 6.1)
elseif(${arch_name} STREQUAL "Volta")
set(arch_bin 7.0)
set(arch_ptx 7.0)
elseif(${arch_name} STREQUAL "Turing")
set(arch_bin 7.5)
set(arch_ptx 7.5)
elseif(${arch_name} STREQUAL "Ampere")
set(arch_bin 8.0)
set(arch_ptx 8.0)
else()
message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
endif()
endif()
if(NOT arch_bin)
message(SEND_ERROR "arch_bin wasn't set for some reason")
endif()
list(APPEND cuda_arch_bin ${arch_bin})
if(add_ptx)
if (NOT arch_ptx)
set(arch_ptx ${arch_bin})
endif()
list(APPEND cuda_arch_ptx ${arch_ptx})
endif()
endforeach()
# remove dots and convert to lists
string(REGEX REPLACE "\\." "" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX REPLACE "\\." "" cuda_arch_ptx "${cuda_arch_ptx}")
string(REGEX MATCHALL "[0-9()]+" cuda_arch_bin "${cuda_arch_bin}")
string(REGEX MATCHALL "[0-9]+" cuda_arch_ptx "${cuda_arch_ptx}")
if(cuda_arch_bin)
list(REMOVE_DUPLICATES cuda_arch_bin)
endif()
if(cuda_arch_ptx)
list(REMOVE_DUPLICATES cuda_arch_ptx)
endif()
set(nvcc_flags "")
set(nvcc_archs_readable "")
# Tell NVCC to add binaries for the specified GPUs
foreach(arch ${cuda_arch_bin})
if(arch MATCHES "([0-9]+)\\(([0-9]+)\\)")
# User explicitly specified ARCH for the concrete CODE
list(APPEND nvcc_flags -gencode arch=compute_${CMAKE_MATCH_2},code=sm_${CMAKE_MATCH_1})
list(APPEND nvcc_archs_readable sm_${CMAKE_MATCH_1})
else()
# User didn't explicitly specify ARCH for the concrete CODE, we assume ARCH=CODE
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=sm_${arch})
list(APPEND nvcc_archs_readable sm_${arch})
endif()
endforeach()
# Tell NVCC to add PTX intermediate code for the specified architectures
foreach(arch ${cuda_arch_ptx})
list(APPEND nvcc_flags -gencode arch=compute_${arch},code=compute_${arch})
list(APPEND nvcc_archs_readable compute_${arch})
endforeach()
string(REPLACE ";" " " nvcc_archs_readable "${nvcc_archs_readable}")
set(${out_variable} ${nvcc_flags} PARENT_SCOPE)
set(${out_variable}_readable ${nvcc_archs_readable} PARENT_SCOPE)
endfunction()
# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
# PYTHON_EXECUTABLE is set by pybind11.cmake
message(STATUS "Python executable: ${PYTHON_EXECUTABLE}")
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import os; import torch; print(os.path.dirname(torch.__file__))"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_DIR
)
list(APPEND CMAKE_PREFIX_PATH "${TORCH_DIR}")
find_package(Torch REQUIRED)
# set the global CMAKE_CXX_FLAGS so that
# fast_rnnt uses the same ABI flag as PyTorch
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
if(FT_WITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${TORCH_CXX_FLAGS}")
endif()
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[0])"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE FT_TORCH_VERSION_MAJOR
)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__.split('.')[1])"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE FT_TORCH_VERSION_MINOR
)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__)"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_VERSION
)
message(STATUS "PyTorch version: ${TORCH_VERSION}")
if(FT_WITH_CUDA)
execute_process(
COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.version.cuda)"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_CUDA_VERSION
)
message(STATUS "PyTorch cuda version: ${TORCH_CUDA_VERSION}")
if(NOT CUDA_VERSION VERSION_EQUAL TORCH_CUDA_VERSION)
message(FATAL_ERROR
"PyTorch ${TORCH_VERSION} is compiled with CUDA ${TORCH_CUDA_VERSION}.\n"
"But you are using CUDA ${CUDA_VERSION} to compile optimized_transducer.\n"
"Please try to use the same CUDA version for PyTorch and optimized_transducer.\n"
"**You can remove this check if you are sure this will not cause "
"problems**\n"
)
endif()
# Work around the following error from NVCC:
# unknown option `-Wall`
#
# INTERFACE_COMPILE_OPTIONS contains only some -W* warning flags,
# so it is OK to set it to empty
set_property(TARGET torch_cuda
PROPERTY
INTERFACE_COMPILE_OPTIONS ""
)
set_property(TARGET torch_cpu
PROPERTY
INTERFACE_COMPILE_OPTIONS ""
)
endif()
add_subdirectory(csrc)
add_subdirectory(python)
include_directories(${CMAKE_SOURCE_DIR})
set(srcs
mutual_information_cpu.cc
)
add_library(mutual_information_core ${srcs})
target_link_libraries(mutual_information_core PUBLIC ${TORCH_LIBRARIES})
if(FT_WITH_CUDA)
set(cuda_srcs mutual_information_cuda.cu)
add_library(mutual_information_core_cuda ${cuda_srcs})
target_link_libraries(mutual_information_core_cuda PUBLIC ${TORCH_LIBRARIES})
target_include_directories(mutual_information_core_cuda PUBLIC ${PYTHON_INCLUDE_DIRS})
target_link_libraries(mutual_information_core PUBLIC mutual_information_core_cuda)
endif()
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#define FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
#include <torch/extension.h>
#include <cmath>
#include <vector>
#ifdef __CUDACC__
#define FT_CUDA_HOSTDEV __host__ __device__
#else
#define FT_CUDA_HOSTDEV
#endif
namespace fast_rnnt {
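// returns log(exp(x) + exp(y)).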
FT_CUDA_HOSTDEV inline double LogAdd(double x, double y) {
double diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
// returns log(exp(x) + exp(y)).
FT_CUDA_HOSTDEV inline float LogAdd(float x, float y) {
float diff;
if (x < y) {
diff = x - y;
x = y;
} else {
diff = y - x;
}
// diff is negative. x is now the larger one.
if (diff - diff != 0)
return x; // x and y are probably -inf. Return the larger one.
else
return x + log1p(exp(diff));
}
/*
Forward of mutual_information. See also the comment of `mutual_information`
in ../python/fast_rnnt/mutual_information.py. This is the core recursion
in the sequence-to-sequence mutual information computation.
@param px Tensor of shape [B][S][T + 1] if not modified, [B][S][T] if
modified. `modified` can be worked out from this. In the not-modified case,
it can be thought of as the log-odds ratio of generating the next x in
the sequence, i.e.
px[b][s][t] is the log of
p(x_s | x_0..x_{s-1}, y_0..y_{t-1}) / p(x_s),
i.e. the log-prob of generating x_s given subsequences of
lengths (s, t), divided by the prior probability of generating x_s.
(See mutual_information.py for more info).
@param py The log-odds ratio of generating the next y in the sequence.
Shape [B][S + 1][T]
@param p This function writes to p[b][s][t] the mutual information between
sub-sequences of x and y of lengths s and t respectively, from the
b'th pair of sequences in the batch. Its shape is [B][S + 1][T + 1].
Concretely, this function implements the following recursion,
in the case where s_begin == t_begin == 0:
p[b,0,0] = 0.0
if not modified:
p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
p[b,s,t-1] + py[b,s,t-1])
if modified:
p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
p[b,s,t-1] + py[b,s,t-1])
... treating values with any -1 index as -infinity.
.. if `boundary` is set, we start from p[b,s_begin,t_begin] = 0.0.
@param boundary If set, a tensor of shape [B][4] of type int64_t; for
each batch element b, boundary[b] equals
[s_begin, t_begin, s_end, t_end],
which are the beginning and end (i.e. one-past-the-last)
of the x and y sequences that we should process.
Alternatively, may be a tensor of shape [0][0] and type
int64_t; the elements will default to (0, 0, S, T).
@return A tensor `ans` of shape [B], where this function will set
ans[b] = p[b][s_end][t_end],
with s_end and t_end being (S, T) if `boundary` was not specified,
and (boundary[b][2], boundary[b][3]) otherwise.
`ans` represents the mutual information between each pair of
sequences (i.e. x[b] and y[b], although the sequences are not
supplied directly to this function).
For the CUDA version, the block-dim and grid-dim must both be
1-dimensional, and the block-dim must be at least 128.
*/
torch::Tensor MutualInformationCpu(
torch::Tensor px, // [B][S][T+1]
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
torch::Tensor MutualInformationCuda(
torch::Tensor px, // [B][S][T+1] if !modified, [B][S][T] if modified.
torch::Tensor py, // [B][S+1][T]
torch::optional<torch::Tensor> boundary, // [B][4], int64_t.
torch::Tensor p); // [B][S+1][T+1]; an output
/*
Backward of mutual_information; returns (grad_px, grad_py).
If overwrite_ans_grad == true, this function will overwrite ans_grad with a
value that, if the computation worked correctly, should be identical to or
very close to the value of ans_grad at entry. This can be used
to validate the correctness of this code.
*/
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad);
std::vector<torch::Tensor> MutualInformationBackwardCuda(
torch::Tensor px, torch::Tensor py, torch::optional<torch::Tensor> boundary,
torch::Tensor p, torch::Tensor ans_grad, bool overwrite_ans_grad);
} // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_MUTUAL_INFORMATION_H_
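An aside on the `diff - diff != 0` test in `LogAdd` above: it detects NaN (which arises as `-inf - -inf`) without calling `std::isnan`, so the function degrades gracefully when one or both inputs are -infinity. Below is a small, hypothetical standalone check (not part of fast_rnnt; `log_add` deliberately duplicates the logic of `LogAdd`):
// logadd_demo.cc -- hypothetical standalone check of the LogAdd logic in
// mutual_information.h; compile with e.g. `g++ -std=c++14 logadd_demo.cc`.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <limits>
// Same trick as fast_rnnt::LogAdd: after the branch, diff <= 0, and
// `diff - diff != 0` holds only when diff is NaN or -inf, which happens
// when x and/or y is -inf; returning the larger input is correct then.
static double log_add(double x, double y) {
  double diff;
  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  if (diff - diff != 0) return x;
  return x + std::log1p(std::exp(diff));  // log(exp(x) + exp(y))
}
int main() {
  const double inf = std::numeric_limits<double>::infinity();
  // Matches the naive formula when nothing overflows: log(2 + 3) = log 5.
  assert(std::fabs(log_add(std::log(2.0), std::log(3.0)) - std::log(5.0)) < 1e-12);
  // Stable where exp() would overflow: log(e^1000 + e^999) ~= 1000.3133.
  std::printf("%.4f\n", log_add(1000.0, 999.0));
  // -inf acts as log(0), the identity element.
  assert(log_add(-inf, 42.0) == 42.0);
  assert(log_add(-inf, -inf) == -inf);
  return 0;
}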
/**
* @copyright
* Copyright 2021 Xiaomi Corporation (authors: Daniel Povey)
*
* @copyright
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fast_rnnt/csrc/mutual_information.h"
namespace fast_rnnt {
// Forward of mutual_information. See the """...""" comment of
// `mutual_information_recursion`
// in ../python/fast_rnnt/mutual_information.py for documentation of the
// behavior of this function.
// px: of shape [B, S, T+1] if !modified, else [B, S, T] <-- work out
// `modified` from this.
// py: of shape [B, S+1, T]
// boundary: of shape [B, 4], containing (s_begin, t_begin, s_end, t_end),
// defaulting to (0, 0, S, T).
// p: of shape [B, S+1, T+1]
// Computes the recursion:
// if !modified:
// p[b,s,t] = log_add(p[b,s-1,t] + px[b,s-1,t],
// p[b,s,t-1] + py[b,s,t-1])
// if modified:
// p[b,s,t] = log_add(p[b,s-1,t-1] + px[b,s-1,t-1],
// p[b,s,t-1] + py[b,s,t-1])
// .. treating out-of-range elements as -infinity and with special cases:
// p[b, s_begin, t_begin] = 0.0
//
// and this function returns a tensor of shape (B,) consisting of elements
// p[b, s_end, t_end]
torch::Tensor MutualInformationCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu(),
"inputs must be CPU tensors");
bool modified = (px.size(2) == py.size(2));
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1 && py.size(2) == T);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
torch::Tensor ans = torch::empty({B}, opts);
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(),
py_a = py.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>();
auto boundary_a = boundary.accessor<int64_t, 2>();
auto ans_a = ans.accessor<scalar_t, 1>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
p_a[b][s_begin][t_begin] = 0.0;
if (modified) {
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] = -std::numeric_limits<scalar_t>::infinity();
} else {
// note: t_offset = 0 so don't need t_begin + t_offset below.
for (int s = s_begin + 1; s <= s_end; ++s)
p_a[b][s][t_begin] =
p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
}
for (int t = t_begin + 1; t <= t_end; ++t)
p_a[b][s_begin][t] =
p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
for (int s = s_begin + 1; s <= s_end; ++s) {
scalar_t p_s_t1 = p_a[b][s][t_begin];
for (int t = t_begin + 1; t <= t_end; ++t) {
// The following statement is a small optimization of:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s -1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
// .. which obtains p_a[b][s][t - 1] from a register.
p_a[b][s][t] = p_s_t1 = LogAdd(p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
p_s_t1 + py_a[b][s][t - 1]);
}
}
ans_a[b] = p_a[b][s_end][t_end];
}
}));
return ans;
}
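For orientation, here is a hedged sketch of how `MutualInformationCpu` might be driven from C++; this is illustrative only, and assumes a program linked against libtorch and this translation unit:
// Hypothetical driver for MutualInformationCpu; not part of fast_rnnt.
#include "fast_rnnt/csrc/mutual_information.h"
#include <torch/torch.h>
#include <iostream>
int main() {
  const int B = 2, S = 3, T = 4;
  // Non-modified case: px is [B][S][T+1], py is [B][S+1][T].
  torch::Tensor px = torch::randn({B, S, T + 1});
  torch::Tensor py = torch::randn({B, S + 1, T});
  torch::Tensor p = torch::empty({B, S + 1, T + 1});  // filled by the call
  // Empty optional: boundary defaults to (0, 0, S, T) for every element.
  torch::Tensor ans = fast_rnnt::MutualInformationCpu(
      px, py, torch::optional<torch::Tensor>(), p);
  std::cout << ans << "\n";  // shape [B]
  return 0;
}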
// backward of mutual_information. Returns (px_grad, py_grad).
// p corresponds to what we computed in the forward pass.
std::vector<torch::Tensor>
MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
torch::optional<torch::Tensor> opt_boundary,
torch::Tensor p, torch::Tensor ans_grad) {
TORCH_CHECK(px.dim() == 3, "px must be 3-dimensional");
TORCH_CHECK(py.dim() == 3, "py must be 3-dimensional.");
TORCH_CHECK(p.dim() == 3, "p must be 3-dimensional.");
TORCH_CHECK(ans_grad.dim() == 1, "ans_grad must be 1-dimensional.");
bool modified = (px.size(2) == py.size(2));
TORCH_CHECK(px.device().is_cpu() && py.device().is_cpu() &&
p.device().is_cpu() && ans_grad.device().is_cpu(),
"inputs must be CPU tensors");
auto scalar_t = px.scalar_type();
auto opts = torch::TensorOptions().dtype(scalar_t).device(px.device());
const int B = px.size(0), S = px.size(1), T = py.size(2);
TORCH_CHECK(px.size(2) == (modified ? T : T + 1));
TORCH_CHECK(py.size(0) == B && py.size(1) == S + 1);
TORCH_CHECK(p.size(0) == B && p.size(1) == S + 1 && p.size(2) == T + 1);
auto boundary = opt_boundary.value_or(
torch::tensor({0, 0, S, T},
torch::dtype(torch::kInt64).device(torch::kCPU))
.reshape({1, 4})
.expand({B, 4}));
TORCH_CHECK(boundary.dim() == 2, "boundary must be 2-dimensional.");
TORCH_CHECK(boundary.size(0) == B && boundary.size(1) == 4);
TORCH_CHECK(boundary.device().is_cpu() && boundary.dtype() == torch::kInt64);
bool has_boundary = opt_boundary.has_value();
int T1 = T + (modified ? 0 : 1);
torch::Tensor p_grad = torch::zeros({B, S + 1, T + 1}, opts),
px_grad = (has_boundary ? torch::zeros({B, S, T1}, opts)
: torch::empty({B, S, T1}, opts)),
py_grad = (has_boundary ? torch::zeros({B, S + 1, T}, opts)
: torch::empty({B, S + 1, T}, opts));
AT_DISPATCH_FLOATING_TYPES(
px.scalar_type(), "mutual_information_cpu_backward_loop", ([&] {
auto px_a = px.accessor<scalar_t, 3>(), p_a = p.accessor<scalar_t, 3>(),
p_grad_a = p_grad.accessor<scalar_t, 3>(),
px_grad_a = px_grad.accessor<scalar_t, 3>(),
py_grad_a = py_grad.accessor<scalar_t, 3>();
auto ans_grad_a = ans_grad.accessor<scalar_t, 1>();
auto boundary_a = boundary.accessor<int64_t, 2>();
int t_offset = (modified ? -1 : 0);
for (int b = 0; b < B; b++) {
int s_begin = boundary_a[b][0];
int t_begin = boundary_a[b][1];
int s_end = boundary_a[b][2];
int t_end = boundary_a[b][3];
// Backprop for: ans_a[b] = p_a[b][s_end][t_end];
p_grad_a[b][s_end][t_end] = ans_grad_a[b];
for (int s = s_end; s > s_begin; --s) {
for (int t = t_end; t > t_begin; --t) {
// The statement we are backpropagating here is:
// p_a[b][s][t] = LogAdd(
// p_a[b][s - 1][t + t_offset] + px_a[b][s - 1][t + t_offset],
// p_a[b][s][t - 1] + py_a[b][s][t - 1]);
scalar_t term1 = p_a[b][s - 1][t + t_offset] +
px_a[b][s - 1][t + t_offset],
// term2 = p_a[b][s][t - 1] + py_a[b][s][t - 1], <-- not
// actually needed..
total = p_a[b][s][t];
if (total - total != 0)
total = 0;
scalar_t term1_deriv = exp(term1 - total),
term2_deriv = 1.0 - term1_deriv,
grad = p_grad_a[b][s][t];
scalar_t term1_grad, term2_grad;
if (term1_deriv - term1_deriv == 0.0) {
term1_grad = term1_deriv * grad;
term2_grad = term2_deriv * grad;
} else {
// could happen if total == -inf
term1_grad = term2_grad = 0.0;
}
px_grad_a[b][s - 1][t + t_offset] = term1_grad;
p_grad_a[b][s - 1][t + t_offset] = term1_grad;
py_grad_a[b][s][t - 1] = term2_grad;
p_grad_a[b][s][t - 1] += term2_grad;
}
}
for (int t = t_end; t > t_begin; --t) {
// Backprop for:
// p_a[b][s_begin][t] =
// p_a[b][s_begin][t - 1] + py_a[b][s_begin][t - 1];
scalar_t this_p_grad = p_grad_a[b][s_begin][t];
p_grad_a[b][s_begin][t - 1] += this_p_grad;
py_grad_a[b][s_begin][t - 1] = this_p_grad;
}
if (!modified) {
for (int s = s_end; s > s_begin; --s) {
// Backprop for:
// p_a[b][s][t_begin] =
// p_a[b][s - 1][t_begin] + px_a[b][s - 1][t_begin];
scalar_t this_p_grad = p_grad_a[b][s][t_begin];
p_grad_a[b][s - 1][t_begin] += this_p_grad;
px_grad_a[b][s - 1][t_begin] = this_p_grad;
}
} // else these were all -infinity's and there is nothing to
// backprop.
// There is no backprop for:
// p_a[b][s_begin][t_begin] = 0.0;
// .. but we can use it as a check: the grad at the beginning of
// the sequence should equal the grad at the end of the sequence.
if (ans_grad_a[b] != 0.0) {
float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
if (fabs(grad_ratio - 1.0) > 0.01) {
// A warning used to be printed here via K2_LOG, saying that
// p_grad_a[b][s_begin][t_begin] and ans_grad_a[b] were expected
// to be (nearly) the same; the call was commented out when this
// code was moved out of k2, so the check is currently a no-op.
}
}
}
}));
return std::vector<torch::Tensor>({px_grad, py_grad});
}
} // namespace fast_rnnt
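To make the recursion concrete outside of torch, here is a minimal, hypothetical reference implementation for one sequence pair (B = 1, non-modified case, no boundary); the name `mutual_information_ref` is illustrative and not part of the library:
// mi_ref.cc -- illustrative reference for the non-modified recursion.
#include <cmath>
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>
static double log_add(double x, double y) {
  if (x < y) std::swap(x, y);
  double diff = y - x;             // <= 0, or NaN if both are -inf
  if (diff - diff != 0) return x;  // handles -inf inputs
  return x + std::log1p(std::exp(diff));
}
// px has shape [S][T+1], py has shape [S+1][T]; returns p[S][T].
double mutual_information_ref(const std::vector<std::vector<double>> &px,
                              const std::vector<std::vector<double>> &py) {
  const double kNegInf = -std::numeric_limits<double>::infinity();
  const int S = (int)px.size(), T = py.empty() ? 0 : (int)py[0].size();
  std::vector<std::vector<double>> p(S + 1,
                                     std::vector<double>(T + 1, kNegInf));
  p[0][0] = 0.0;
  for (int s = 0; s <= S; ++s) {
    for (int t = 0; t <= T; ++t) {
      if (s == 0 && t == 0) continue;
      // p[s][t] = log_add(p[s-1][t] + px[s-1][t], p[s][t-1] + py[s][t-1]),
      // treating any -1 index as -infinity.
      double via_x = (s > 0) ? p[s - 1][t] + px[s - 1][t] : kNegInf;
      double via_y = (t > 0) ? p[s][t - 1] + py[s][t - 1] : kNegInf;
      p[s][t] = log_add(via_x, via_y);
    }
  }
  return p[S][T];
}
int main() {
  // S = 1, T = 1: exactly two paths through the grid, so
  // p[1][1] = log_add(py[0][0] + px[0][1], px[0][0] + py[1][0])
  //         = log_add(0.5, 0.5) = 0.5 + log(2) ~= 1.1931.
  std::vector<std::vector<double>> px = {{0.1, 0.2}};    // [S][T+1]
  std::vector<std::vector<double>> py = {{0.3}, {0.4}};  // [S+1][T]
  std::printf("%.4f\n", mutual_information_ref(px, py));
  return 0;
}
The backward pass in MutualInformationBackwardCpu is the chain rule through this same LogAdd. In the code's notation (total = p_a[b][s][t], term1 the px branch, term2 the py branch), the derivatives it uses are:
\[
p = \log\left(e^{\mathrm{term1}} + e^{\mathrm{term2}}\right)
\;\Longrightarrow\;
\frac{\partial p}{\partial\,\mathrm{term1}}
  = \frac{e^{\mathrm{term1}}}{e^{\mathrm{term1}} + e^{\mathrm{term2}}}
  = e^{\mathrm{term1} - \mathrm{total}},
\qquad
\frac{\partial p}{\partial\,\mathrm{term2}} = 1 - e^{\mathrm{term1} - \mathrm{total}},
\]
which is exactly `term1_deriv = exp(term1 - total)` and `term2_deriv = 1.0 - term1_deriv`, each then scaled by the incoming gradient `p_grad_a[b][s][t]` and accumulated into the px/p and py/p gradients respectively.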
add_subdirectory(csrc)
add_subdirectory(tests)
include_directories(${CMAKE_SOURCE_DIR})
pybind11_add_module(_fast_rnnt
mutual_information.cu
)
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)
if(UNIX AND NOT APPLE)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()