    Enable log probs input for rnnt loss (#2798) · ca478823
    Caroline Chen authored
    Summary:
    Add a `fused_log_softmax` argument to the RNN-T loss (default `True`, which preserves the current behavior).
    
    If it is set to `False`, call `log_softmax` on the logits before passing them to the RNN-T loss function.
    
    The following two calls should produce the same output:
    ```
    from torchaudio.functional import rnnt_loss
    rnnt_loss(logits, targets, logit_lengths, target_lengths, fused_log_softmax=True)
    ```
    
    ```
    import torch
    from torchaudio.functional import rnnt_loss
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
    rnnt_loss(log_probs, targets, logit_lengths, target_lengths, fused_log_softmax=False)
    ```
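To make the equivalence above concrete without depending on torchaudio, here is a minimal pure-Python sketch of a single-utterance RNN-T loss (forward algorithm only, no gradients). The names `rnnt_loss_single`, `log_softmax`, and `logaddexp` below are hypothetical helpers for illustration, not torchaudio's implementation. The sketch assumes logits laid out as nested lists of shape `T x (U+1) x V` and the blank at the last class index, mirroring the default `blank=-1`.

```python
import math
from typing import List, Sequence

NestedLogits = List[List[List[float]]]  # shape: T x (U + 1) x V


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of class scores."""
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def logaddexp(a: float, b: float) -> float:
    """log(exp(a) + exp(b)) without overflow; handles -inf inputs."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_loss_single(
    logits: NestedLogits,
    targets: Sequence[int],
    blank: int = -1,
    fused_log_softmax: bool = True,
) -> float:
    """Hypothetical reference RNN-T loss for one utterance (forward pass only).

    With fused_log_softmax=True, log_softmax is applied here to the raw scores.
    With fused_log_softmax=False, `logits` must already be log-probabilities.
    """
    T, U = len(logits), len(targets)
    V = len(logits[0][0])
    blank = blank % V  # -1 -> last class index
    if fused_log_softmax:
        lp = [[log_softmax(frame[u]) for u in range(U + 1)] for frame in logits]
    else:
        lp = logits  # caller already normalized the scores

    # alpha[t][u]: log-probability of reaching lattice node (t, u),
    # i.e. having output the first u labels while at frame t.
    alpha = [[-math.inf] * (U + 1) for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):
        for u in range(U + 1):
            if t == 0 and u == 0:
                continue
            # Arrive by emitting blank at (t - 1, u): advance one frame.
            from_blank = alpha[t - 1][u] + lp[t - 1][u][blank] if t > 0 else -math.inf
            # Arrive by emitting label targets[u - 1] at (t, u - 1): stay on frame t.
            from_label = alpha[t][u - 1] + lp[t][u - 1][targets[u - 1]] if u > 0 else -math.inf
            alpha[t][u] = logaddexp(from_blank, from_label)

    # A final blank at (T - 1, U) terminates every valid alignment.
    return -(alpha[T - 1][U] + lp[T - 1][U][blank])


# Deterministic toy scores: T = 3 frames, U = 2 target labels, V = 3 classes (blank = 2).
logits = [
    [[math.sin(7 * t + 3 * u + v) for v in range(3)] for u in range(3)]
    for t in range(3)
]
targets = [0, 1]

fused = rnnt_loss_single(logits, targets, fused_log_softmax=True)
log_probs = [[log_softmax(row) for row in frame] for frame in logits]
unfused = rnnt_loss_single(log_probs, targets, fused_log_softmax=False)
print(fused, unfused)
```

Passing already-normalized log-probabilities with `fused_log_softmax=False` reproduces the fused result. Passing raw logits with `fused_log_softmax=False` would not, because the sketch then treats unnormalized scores as log-probabilities.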
    
    Testing: unit tests, plus identical results on the Conformer RNN-T recipe.
    
    Pull Request resolved: https://github.com/pytorch/audio/pull/2798
    
    Reviewed By: xiaohui-zhang
    
    Differential Revision: D41083523
    
    Pulled By: carolineechen
    
    fbshipit-source-id: e15442ceed1f461bbf06b724aa0561ff8827ad61