Commit 0818d487 authored by pkufool's avatar pkufool

Update Readme

parent 912b5f55
This project implements a method for faster and more memory-efficient RNN-T loss computation, called `pruned rnnt`.
It contains an implementation of a particular recursion that turns out to be useful for fast RNN-T computation.
## How does the pruned rnnt work?
We first obtain pruning bounds for the RNN-T recursion using a simple joiner network that is just an addition of the encoder and decoder outputs; we then use those pruning bounds to evaluate the full, non-linear joiner network.
The picture below displays the gradients of the transducer lattice nodes (obtained from `rnnt_loss_simple` with `return_grad=True`). At each time frame, only a small set of nodes has a non-zero gradient, which justifies the pruned RNN-T loss, i.e., putting a limit on the number of symbols per frame.
<img src="https://user-images.githubusercontent.com/5284924/158116784-4dcf1107-2b84-4c0c-90c3-cb4a02f027c9.png" width="500" height="125" />
> This picture is taken from [here](https://github.com/k2-fsa/icefall/pull/251)
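To make the two-pass idea concrete, here is a minimal sketch of the trivial joiner used to obtain the pruning bounds. This is not fast_rnnt's actual code; the sizes `B`, `T`, `S`, `C` are illustrative. The point is that the first pass is just a broadcast addition of the encoder and decoder outputs over the lattice:

```python
import torch

# Illustrative sizes: B utterances, T frames, S symbols, vocabulary of size C.
B, T, S, C = 2, 10, 5, 8
am = torch.randn(B, T, C)      # encoder ("acoustic model") output
lm = torch.randn(B, S + 1, C)  # decoder ("language model") output

# The trivial joiner is a broadcast addition over the (T, S+1) lattice;
# no non-linearity is involved, which is what makes the first pass cheap.
simple_joint = am.unsqueeze(2) + lm.unsqueeze(1)
assert simple_joint.shape == (B, T, S + 1, C)
```

The full joiner, with its non-linearity, is then evaluated only inside the pruning bounds derived from this cheap first pass.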
## Installation
You can install it via `pip`:
```
pip install fast_rnnt
```
You can also install from source:
```
git clone https://github.com/danpovey/fast_rnnt.git
cd fast_rnnt
python setup.py install
```
To check that `fast_rnnt` was installed successfully, please run
```
python3 -c "import fast_rnnt; print(fast_rnnt.__version__)"
```
which should print the version of the installed `fast_rnnt`, e.g., `1.0`.
### How to display the installation log?
Use
```
pip install --verbose fast_rnnt
```
### How to reduce the installation time?
Use
```
export FT_MAKE_ARGS="-j"
pip install --verbose fast_rnnt
```
It will pass `-j` to `make` so that the build runs in parallel.
### Which version of PyTorch is supported?
It has been tested on PyTorch >= 1.5.0.
Note: the CUDA version used to build PyTorch must match the CUDA version installed in your environment; otherwise you will get a compilation error.
### How to install a CPU version of `fast_rnnt`?
Use
```
export FT_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Release -DFT_WITH_CUDA=OFF"
export FT_MAKE_ARGS="-j"
pip install --verbose fast_rnnt
```
It will pass `-DCMAKE_BUILD_TYPE=Release -DFT_WITH_CUDA=OFF` to `cmake`.
### Where to get help if I have problems with the installation?
Please file an issue at <https://github.com/danpovey/fast_rnnt/issues>
and describe your problem there.
## Usage
### For rnnt_loss_simple
This is a simple case of the RNN-T loss, where the 'joiner' network is just
addition.
```python
import torch
import fast_rnnt

# Illustrative sizes: B utterances, T frames, S symbols, vocabulary of size C.
B, T, S, C = 2, 10, 5, 8
target_lengths = torch.full((B,), S, dtype=torch.int64)
num_frames = torch.full((B,), T, dtype=torch.int64)

am = torch.randn((B, T, C), dtype=torch.float32)
lm = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols = torch.randint(0, C, (B, S))
termination_symbol = 0
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
)
```
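A note on the `boundary` tensor: to my reading of the API, each row holds `[begin_symbol, begin_frame, end_symbol, end_frame]` for one utterance, which is why columns 2 and 3 are set to the target lengths and frame counts (please verify against the docstring). A small illustrative example with made-up lengths:

```python
import torch

# Hypothetical batch of 2 utterances with 3 and 5 target symbols,
# and 8 and 10 acoustic frames respectively.
target_lengths = torch.tensor([3, 5])
num_frames = torch.tensor([8, 10])

B = 2
# Row layout: [begin_symbol, begin_frame, end_symbol, end_frame].
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
assert boundary.tolist() == [[0, 0, 3, 8], [0, 0, 5, 10]]
```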
### For rnnt_loss_smoothed
The same as `rnnt_loss_simple`, except that it supports `am_only` and `lm_only` smoothing, which lets you construct a loss function of the form:

    lm_only_scale * lm_probs +
    am_only_scale * am_probs +
    (1 - lm_only_scale - am_only_scale) * combined_probs

where `lm_probs` and `am_probs` are the probabilities given by the lm and the acoustic model independently.
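The weighting above can be sketched in plain PyTorch. This is only an illustration of the convex combination, not fast_rnnt's implementation; the shapes and scale values are made up:

```python
import torch

B, T, S, C = 2, 4, 3, 5
am = torch.randn(B, T, C)
lm = torch.randn(B, S + 1, C)

# Log-probabilities from each source alone, and from their combination.
am_probs = am.log_softmax(-1).unsqueeze(2)  # (B, T, 1, C)
lm_probs = lm.log_softmax(-1).unsqueeze(1)  # (B, 1, S+1, C)
combined_probs = (am.unsqueeze(2) + lm.unsqueeze(1)).log_softmax(-1)

lm_only_scale, am_only_scale = 0.25, 0.0
smoothed = (lm_only_scale * lm_probs
            + am_only_scale * am_probs
            + (1 - lm_only_scale - am_only_scale) * combined_probs)
assert smoothed.shape == (B, T, S + 1, C)
```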
```python
am = torch.randn((B, T, C), dtype=torch.float32)
lm = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols = torch.randint(0, C, (B, S))
termination_symbol = 0
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.25,
am_only_scale=0.0,
boundary=boundary,
reduction="sum",
)
```
### For rnnt_loss_pruned
`rnnt_loss_pruned` cannot be used on its own; it needs the gradients returned by
`rnnt_loss_simple`/`rnnt_loss_smoothed` to compute its pruning bounds.
```python
am = torch.randn((B, T, C), dtype=torch.float32)
lm = torch.randn((B, S + 1, C), dtype=torch.float32)
symbols = torch.randint(0, C, (B, S))
termination_symbol = 0
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
# rnnt_loss_smoothed could be used here instead of rnnt_loss_simple
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
return_grad=True,
)
s_range = 5 # can be other values
ranges = fast_rnnt.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=s_range,
)
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = model.joiner(am_pruned, lm_pruned)  # your full, non-linear joiner network
pruned_loss = fast_rnnt.rnnt_loss_pruned(
logits=logits,
symbols=symbols,
ranges=ranges,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
)
```
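To give an intuition for what `do_rnnt_pruning` produces, here is a rough, hypothetical sketch of the lm side of the pruning step: for each frame `t`, only `s_range` consecutive symbol positions (taken from `ranges`) are kept, so the full joiner sees a `(B, T, s_range, C)` tensor instead of `(B, T, S+1, C)`. This is an illustration, not the library's implementation; the pruning bounds below are made up rather than derived from gradients:

```python
import torch

B, T, S, C, s_range = 2, 6, 4, 8, 2
lm = torch.randn(B, S + 1, C)

# Made-up monotone pruning bounds: at frame t we keep symbol positions
# [start, start + s_range), with start clamped so indices stay valid.
starts = torch.clamp(torch.arange(T).unsqueeze(0).expand(B, T),
                     max=S + 1 - s_range)
ranges = starts.unsqueeze(-1) + torch.arange(s_range)  # (B, T, s_range)

# Gather the kept lm rows per frame.
idx = ranges.unsqueeze(-1).expand(B, T, s_range, C)
lm_pruned = torch.gather(lm.unsqueeze(1).expand(B, T, S + 1, C), 2, idx)
assert lm_pruned.shape == (B, T, s_range, C)
```

The am side needs no gathering: every frame is kept, so `am` is only broadcast against the pruned lm rows.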
You can also find recipes [here](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless) that use `rnnt_loss_pruned` to train a model.
### For rnnt_loss
The unpruned `rnnt_loss` is the same as torchaudio's `rnnt_loss`; it produces the same output as torchaudio for the same input.
```python
logits = torch.randn((B, T, S + 1, C), dtype=torch.float32)
symbols = torch.randint(0, C, (B, S))
termination_symbol = 0
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss(
logits=logits,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
)
```
## Benchmarking
This [repo](https://github.com/csukuangfj/transducer-loss-benchmarking) compares the speed and memory usage of several transducer losses. The summary in the following table is taken from there; see the repository for more details.
|Name |Average step time (us) | Peak memory usage (MB)|
|--------------------|-----------------------|-----------------------|
|torchaudio |601447 |12959.2 |
|k2 |274407 |15106.5 |
|k2 pruned |38112 |2647.8 |
|optimized_transducer|567684 |10903.1 |
|warprnnt_numba |229340 |13061.8 |
|warp-transducer |210772 |13061.8 |
# Try to find Valgrind headers and libraries.
#
# Usage of this module as follows:
# find_package(Valgrind)
#
# Variables used by this module, they can change the default behaviour and need
# to be set before calling find_package:
#
# VALGRIND_ROOT Set this variable to the root installation of valgrind if the
# module has problems finding the proper installation path.
#
# Variables defined by this module:
# Valgrind_FOUND System has valgrind
# Valgrind_INCLUDE_DIR where to find valgrind/memcheck.h, etc.
# Valgrind_EXECUTABLE the valgrind executable.
# Get hint from environment variable (if any)
if(NOT VALGRIND_ROOT AND DEFINED ENV{VALGRIND_ROOT})
set(VALGRIND_ROOT "$ENV{VALGRIND_ROOT}" CACHE PATH "Valgrind base directory location (optional, used for nonstandard installation paths)")
mark_as_advanced(VALGRIND_ROOT)
endif()
# Search path for nonstandard locations
if(VALGRIND_ROOT)
set(Valgrind_INCLUDE_PATH PATHS "${VALGRIND_ROOT}/include" NO_DEFAULT_PATH)
set(Valgrind_BINARY_PATH PATHS "${VALGRIND_ROOT}/bin" NO_DEFAULT_PATH)
endif()
find_path(Valgrind_INCLUDE_DIR valgrind HINTS ${Valgrind_INCLUDE_PATH})
find_program(Valgrind_EXECUTABLE NAMES valgrind PATHS ${Valgrind_BINARY_PATH})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Valgrind DEFAULT_MSG Valgrind_INCLUDE_DIR Valgrind_EXECUTABLE)
mark_as_advanced(Valgrind_INCLUDE_DIR Valgrind_EXECUTABLE)
if(NOT Valgrind_FOUND)
if(Valgrind_FIND_REQUIRED)
message(FATAL_ERROR "Valgrind is required but does not seem to be installed.")
endif()
else()
message(STATUS "Found Valgrind: ${Valgrind_EXECUTABLE}")
endif()
@@ -42,7 +42,7 @@ class TestMutualInformation(unittest.TestCase):
         cls.dtypes = [torch.float32, torch.float64]

     def test_mutual_information_basic(self):
-        for _iter in range(100):
+        for _iter in range(10):
             (B, S, T) = (
                 random.randint(1, 10),
                 random.randint(1, 16),
@@ -187,7 +187,7 @@ class TestMutualInformation(unittest.TestCase):
             )

     def test_mutual_information_deriv(self):
-        for _iter in range(100):
+        for _iter in range(10):
             (B, S, T) = (
                 random.randint(1, 100),
                 random.randint(1, 200),
...
@@ -113,6 +113,7 @@ setuptools.setup(
     },
     packages=[package_name],
     url="https://github.com/danpovey/fast_rnnt",
+    description="Fast and memory-efficient RNN-T loss.",
     long_description=read_long_description(),
     long_description_content_type="text/markdown",
     install_requires=get_requirements(),
...