Commit 90edd9ab authored by pkufool

Minor fixes

parent 0818d487
@@ -3,15 +3,11 @@ This project implements a method for faster and more memory-efficient RNN-T comp
## How does the pruned-rnnt work ?
We first obtain pruning bounds for the RNN-T recursion using a simple joiner network that is just an addition of the encoder and decoder; then we use those pruning bounds to evaluate the full, non-linear joiner network.
The picture below shows the gradients (obtained by `rnnt_loss_simple` with `return_grad=True`) of the lattice nodes. At each time frame, only a small set of nodes has a non-zero gradient, which justifies the pruned RNN-T loss, i.e., putting a limit on the number of symbols per frame.
<img src="https://user-images.githubusercontent.com/5284924/158116784-4dcf1107-2b84-4c0c-90c3-cb4a02f027c9.png" width="900" height="250" />
> This picture is taken from [here](https://github.com/k2-fsa/icefall/pull/251)
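
The gradients in the picture come from the `return_grad` option of `rnnt_loss_simple`. Below is a minimal sketch of how they can be obtained; the tensor names and sizes are illustrative assumptions, not the project's own example.

```python
import torch
import fast_rnnt

# Illustrative sizes (assumptions): batch, frames, symbols, vocabulary.
B, T, S, C = 2, 50, 10, 30

am = torch.randn((B, T, C), dtype=torch.float32)      # encoder (acoustic) output
lm = torch.randn((B, S + 1, C), dtype=torch.float32)  # decoder (prediction) output
symbols = torch.randint(1, C, (B, S))                 # reference symbol ids
boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = S  # number of symbols in each sequence
boundary[:, 3] = T  # number of frames in each sequence

# With return_grad=True the simple loss also returns the lattice-node
# gradients (px_grad, py_grad) that the picture visualizes.
loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    boundary=boundary,
    return_grad=True,
)
```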
@@ -89,7 +85,7 @@ and describe your problem there.
### For rnnt_loss_simple
This is a simple case of the RNN-T loss, where the joiner network is just
addition.
```python
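# The full example is not reproduced in this diff; the lines below are a
# minimal sketch of a typical call, reusing the tensors defined in the sketch
# above (B, T, S, C, am, lm, symbols and boundary are assumptions, not the
# project's own values).
loss = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    boundary=boundary,
)
```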
@@ -114,14 +110,14 @@ loss = fast_rnnt.rnnt_loss_simple(
### For rnnt_loss_smoothed
The same as `rnnt_loss_simple`, except that it supports `am_only` & `lm_only` smoothing, which allows you to make the loss function of the form:
    lm_only_scale * lm_probs +
    am_only_scale * am_probs +
    (1 - lm_only_scale - am_only_scale) * combined_probs
where `lm_probs` and `am_probs` are the probabilities given by the language model and the acoustic model independently.
```python
am = torch.randn((B, T, C), dtype=torch.float32)
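# The rest of the original example is not reproduced in this diff. This is a
# minimal sketch of a typical call; it assumes B, T, S, C, symbols and
# boundary are defined as in the rnnt_loss_simple sketch above, and the
# smoothing scales are illustrative values only.
lm = torch.randn((B, S + 1, C), dtype=torch.float32)

loss = fast_rnnt.rnnt_loss_smoothed(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    lm_only_scale=0.25,
    am_only_scale=0.0,
    boundary=boundary,
)
```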
@@ -147,8 +143,7 @@ loss = fast_rnnt.rnnt_loss_simple(
### For rnnt_loss_pruned
`rnnt_loss_pruned` cannot be used alone; it needs the gradients returned by `rnnt_loss_simple`/`rnnt_loss_smoothed` to get pruning bounds.
```python
am = torch.randn((B, T, C), dtype=torch.float32)
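# Only the first line of the original example is visible in this diff. The
# lines below sketch the usual two-stage pipeline; the function and argument
# names follow the fast_rnnt pruning utilities (an assumption here) and
# should be checked against the installed version. B, T, S, C, symbols and
# boundary are assumed to be defined as in the rnnt_loss_simple sketch above.
lm = torch.randn((B, S + 1, C), dtype=torch.float32)

# Stage 1: the simple loss, also returning the gradients used for pruning.
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=0,
    boundary=boundary,
    return_grad=True,
)

# Turn the gradients into pruning bounds; s_range is the number of symbols
# kept per frame.
ranges = fast_rnnt.get_rnnt_prune_ranges(
    px_grad=px_grad,
    py_grad=py_grad,
    boundary=boundary,
    s_range=5,
)

# Prune the encoder and decoder outputs down to those bounds.
am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)

# Stage 2: evaluate the full joiner only on the pruned positions. A real
# joiner network would replace this plain addition.
logits = am_pruned + lm_pruned

pruned_loss = fast_rnnt.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=0,
    boundary=boundary,
)
```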
@@ -196,8 +191,7 @@ You can also find recipes [here](https://github.com/k2-fsa/icefall/tree/master/e
### For rnnt_loss
The unpruned `rnnt_loss` is the same as the torchaudio `rnnt_loss`; it produces the same output as torchaudio for the same input.
```python
logits = torch.randn((B, T, S + 1, C), dtype=torch.float32)
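# The rest of the original example is not visible in this diff. Below is a
# minimal sketch of a typical call, assuming B, T, S, C, symbols and boundary
# are defined as in the rnnt_loss_simple sketch above.
loss = fast_rnnt.rnnt_loss(
    logits=logits,
    symbols=symbols,
    termination_symbol=0,
    boundary=boundary,
)
```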