Commit 2cb87c6b authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Refactor RNNT factory function to support num_symbols argument (#2178)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2178

Reviewed By: mthrok

Differential Revision: D33797649

Pulled By: nateanl

fbshipit-source-id: 7a8f54294e7b5bd4d343c8e361e747bfd8b5b603
parent 39fe9df6
......@@ -165,7 +165,7 @@ class RNNTModule(LightningModule):
):
super().__init__()
self.model = emformer_rnnt_base()
self.model = emformer_rnnt_base(num_symbols=4097)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0)
......
......@@ -751,9 +751,12 @@ def emformer_rnnt_model(
return RNNT(transcriber, predictor, joiner)
def emformer_rnnt_base() -> RNNT:
def emformer_rnnt_base(num_symbols: int) -> RNNT:
r"""Builds basic version of Emformer RNN-T model.
Args:
num_symbols (int): The size of target token lexicon.
Returns:
RNNT:
Emformer RNN-T model.
......@@ -761,7 +764,7 @@ def emformer_rnnt_base() -> RNNT:
return emformer_rnnt_model(
input_dim=80,
encoding_dim=1024,
num_symbols=4097,
num_symbols=num_symbols,
segment_length=16,
right_context_length=4,
time_reduction_input_dim=128,
......
......@@ -3,6 +3,7 @@ import math
import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Tuple
import torch
......@@ -364,7 +365,7 @@ class RNNTBundle:
EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
_rnnt_path="emformer_rnnt_base_librispeech.pt",
_rnnt_factory_func=emformer_rnnt_base,
_rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
_global_stats_path="global_stats_rnnt_librispeech.json",
_sp_model_path="spm_bpe_4096_librispeech.model",
_right_padding=4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment