Commit 8a6c12ff authored by pkufool's avatar pkufool
Browse files

Minor fixes

parent 90edd9ab
...@@ -99,12 +99,12 @@ boundary[:, 2] = target_lengths ...@@ -99,12 +99,12 @@ boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss_simple( loss = fast_rnnt.rnnt_loss_simple(
lm=lm, lm=lm,
am=am, am=am,
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
reduction="sum", reduction="sum",
) )
``` ```
...@@ -129,15 +129,15 @@ boundary = torch.zeros((B, 4), dtype=torch.int64) ...@@ -129,15 +129,15 @@ boundary = torch.zeros((B, 4), dtype=torch.int64)
boundary[:, 2] = target_lengths boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss_simple( loss = fast_rnnt.rnnt_loss_smoothed(
lm=lm, lm=lm,
am=am, am=am,
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
lm_only_scale=0.25, lm_only_scale=0.25,
am_only_scale=0.0 am_only_scale=0.0
boundary=boundary, boundary=boundary,
reduction="sum", reduction="sum",
) )
``` ```
...@@ -157,13 +157,13 @@ boundary[:, 3] = num_frames ...@@ -157,13 +157,13 @@ boundary[:, 3] = num_frames
# rnnt_loss_simple can be also rnnt_loss_smoothed # rnnt_loss_simple can be also rnnt_loss_smoothed
simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple( simple_loss, (px_grad, py_grad) = fast_rnnt.rnnt_loss_simple(
lm=lm, lm=lm,
am=am, am=am,
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
reduction="sum", reduction="sum",
return_grad=True, return_grad=True,
) )
s_range = 5 # can be other values s_range = 5 # can be other values
ranges = fast_rnnt.get_rnnt_prune_ranges( ranges = fast_rnnt.get_rnnt_prune_ranges(
...@@ -203,11 +203,11 @@ boundary[:, 2] = target_lengths ...@@ -203,11 +203,11 @@ boundary[:, 2] = target_lengths
boundary[:, 3] = num_frames boundary[:, 3] = num_frames
loss = fast_rnnt.rnnt_loss( loss = fast_rnnt.rnnt_loss(
logits=logits, logits=logits,
symbols=symbols, symbols=symbols,
termination_symbol=termination_symbol, termination_symbol=termination_symbol,
boundary=boundary, boundary=boundary,
reduction="sum", reduction="sum",
) )
``` ```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment