Commit 8a6c12ff authored by pkufool's avatar pkufool
Browse files

Minor fixes

parent 90edd9ab
......@@ -99,12 +99,12 @@ boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
)
```
......@@ -129,15 +129,15 @@ boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.25,
am_only_scale=0.0
boundary=boundary,
reduction="sum",
loss = fast_rnnt.rnnt_loss_smoothed(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
lm_only_scale=0.25,
am_only_scale=0.0
boundary=boundary,
reduction="sum",
)
```
......@@ -157,13 +157,13 @@ boundary[:, 3] = num_frames
# rnnt_loss_simple can be also rnnt_loss_smoothed
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
return_grad=True,
lm=lm,
am=am,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
return_grad=True,
)
s_range = 5 # can be other values
ranges = fast_rnnt.get_rnnt_prune_ranges(
......@@ -203,11 +203,11 @@ boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss(
logits=logits,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
logits=logits,
symbols=symbols,
termination_symbol=termination_symbol,
boundary=boundary,
reduction="sum",
)
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment