Commit ffc9c8cc authored by Kritika Singh, committed by Facebook Github Bot

Make CTC work with more encoder-only models

Summary:
Changes include:
1. Added get_normalized_probs to the encoder-only base class FairseqEncoderModel
2. Made CTCCriterion work for both batch_first (LSTMSubsampleEncoderModel) and batch_second (LSTMEncoderOnly) encoder types (a sketch of the idea follows this list)
3. Added tests for different encoder and CTC combinations.
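
For reference, here is a minimal, hypothetical sketch of the idea behind item 2, not the actual CTCCriterion code: normalize the encoder output with the new base-class hook, transpose batch-first output into the (T, B, C) layout that torch.nn.functional.ctc_loss expects, then compute the loss. The sample keys, the blank index, and the assumption that the encoder preserves input lengths are all illustrative.

import torch
import torch.nn.functional as F

def ctc_loss_from_encoder_output(model, sample, net_output, batch_first):
    # Hypothetical helper, not part of this diff.
    # CTC needs log-probabilities; use the hook added to FairseqEncoderModel.
    lprobs = model.get_normalized_probs(net_output, log_probs=True)

    if batch_first:
        # Batch-first encoders emit (B, T, C); ctc_loss expects (T, B, C),
        # so transpose before handing the tensor to the loss.
        lprobs = lprobs.transpose(0, 1)

    # Assumed sample layout; a subsampling encoder would need its
    # input_lengths adjusted to the number of output frames.
    input_lengths = sample["net_input"]["src_lengths"]
    targets = sample["target"]
    target_lengths = sample["target_lengths"]

    return F.ctc_loss(
        lprobs, targets, input_lengths, target_lengths,
        blank=0,            # assumed blank index
        zero_infinity=True,
    )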

TODO:
CTC still doesn't work for VGGLSTMEncoderModel, so I have disabled it. I will debug it and send out a fix in another diff.

Reviewed By: jay-mahadeokar

Differential Revision: D15158818

fbshipit-source-id: acb484bad705c937d676d2c3dcde3e3562d68ed9
parent e112d501
@@ -326,6 +326,17 @@ class FairseqEncoderModel(BaseFairseqModel):
"""
return self.encoder(src_tokens, src_lengths)
def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output."""
encoder_out = net_output['encoder_out']
if torch.is_tensor(encoder_out):
logits = encoder_out.float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(logits, dim=-1)
raise NotImplementedError
def max_positions(self):
"""Maximum length supported by the model."""
return self.encoder.max_positions()
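
A quick illustration of the new hook (assuming model is any FairseqEncoderModel subclass whose encoder returns a dict with a tensor under 'encoder_out'; the shapes are made up):

net_output = {'encoder_out': torch.randn(50, 8, 42)}  # e.g. (T, B, vocab) logits
lprobs = model.get_normalized_probs(net_output, log_probs=True)
# Each frame's probabilities sum to one after normalization.
assert torch.allclose(lprobs.exp().sum(dim=-1), torch.ones(50, 8))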