Commit 90edd9ab authored by pkufool

Minor fixes

parent 0818d487
This project implements a method for faster and more memory-efficient RNN-T computation.
## How does the pruned-rnnt work?
We first obtain pruning bounds for the RNN-T recursion using a simple joiner network that is just an addition of the encoder and decoder; then we use those pruning bounds to evaluate the full, non-linear joiner network.
The picture below displays the gradients (obtained by `rnnt_loss_simple` with `return_grad=True`) of the lattice nodes. At each time frame, only a small set of nodes has a non-zero gradient, which justifies the pruned RNN-T loss, i.e., putting a limit on the number of symbols per frame.
<img src="https://user-images.githubusercontent.com/5284924/158116784-4dcf1107-2b84-4c0c-90c3-cb4a02f027c9.png" width="900" height="250" />
> This picture is taken from [here](https://github.com/k2-fsa/icefall/pull/251)
### For rnnt_loss_simple
This is a simple case of the RNN-T loss, where the joiner network is just addition.
### For rnnt_loss_smoothed
The same as `rnnt_loss_simple`, except that it supports `am_only` & `lm_only` smoothing, which allows you to make the loss function one of the form:
    lm_only_scale * lm_probs +
    am_only_scale * am_probs +
    (1 - lm_only_scale - am_only_scale) * combined_probs
where `lm_probs` and `am_probs` are the probabilities given the LM and acoustic model independently.
### For rnnt_loss_pruned
`rnnt_loss_pruned` cannot be used alone; it needs the gradients returned by `rnnt_loss_simple`/`rnnt_loss_smoothed` to get pruning bounds.
### For rnnt_loss
The unpruned `rnnt_loss` is the same as torchaudio's `rnnt_loss`; it produces the same output as torchaudio for the same input.