Commit 2cb87c6b authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Refactor RNNT factory function to support num_symbols argument (#2178)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2178

Reviewed By: mthrok

Differential Revision: D33797649

Pulled By: nateanl

fbshipit-source-id: 7a8f54294e7b5bd4d343c8e361e747bfd8b5b603
parent 39fe9df6
...@@ -165,7 +165,7 @@ class RNNTModule(LightningModule): ...@@ -165,7 +165,7 @@ class RNNTModule(LightningModule):
): ):
super().__init__() super().__init__()
self.model = emformer_rnnt_base() self.model = emformer_rnnt_base(num_symbols=4097)
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0) self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8) self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0)
......
...@@ -751,9 +751,12 @@ def emformer_rnnt_model( ...@@ -751,9 +751,12 @@ def emformer_rnnt_model(
return RNNT(transcriber, predictor, joiner) return RNNT(transcriber, predictor, joiner)
def emformer_rnnt_base() -> RNNT: def emformer_rnnt_base(num_symbols: int) -> RNNT:
r"""Builds basic version of Emformer RNN-T model. r"""Builds basic version of Emformer RNN-T model.
Args:
num_symbols (int): The size of target token lexicon.
Returns: Returns:
RNNT: RNNT:
Emformer RNN-T model. Emformer RNN-T model.
...@@ -761,7 +764,7 @@ def emformer_rnnt_base() -> RNNT: ...@@ -761,7 +764,7 @@ def emformer_rnnt_base() -> RNNT:
return emformer_rnnt_model( return emformer_rnnt_model(
input_dim=80, input_dim=80,
encoding_dim=1024, encoding_dim=1024,
num_symbols=4097, num_symbols=num_symbols,
segment_length=16, segment_length=16,
right_context_length=4, right_context_length=4,
time_reduction_input_dim=128, time_reduction_input_dim=128,
......
...@@ -3,6 +3,7 @@ import math ...@@ -3,6 +3,7 @@ import math
import pathlib import pathlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Tuple from typing import Callable, List, Tuple
import torch import torch
...@@ -364,7 +365,7 @@ class RNNTBundle: ...@@ -364,7 +365,7 @@ class RNNTBundle:
EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle( EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
_rnnt_path="emformer_rnnt_base_librispeech.pt", _rnnt_path="emformer_rnnt_base_librispeech.pt",
_rnnt_factory_func=emformer_rnnt_base, _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
_global_stats_path="global_stats_rnnt_librispeech.json", _global_stats_path="global_stats_rnnt_librispeech.json",
_sp_model_path="spm_bpe_4096_librispeech.model", _sp_model_path="spm_bpe_4096_librispeech.model",
_right_padding=4, _right_padding=4,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment