Commit ffeba11a authored by mayp777

UPDATE

parent 29deb085
......@@ -21,11 +21,13 @@ A pre-trained model and associated pipelines are expressed as an instance of ``B
Under the hood, the implementations of ``Bundle`` use components from other ``torchaudio`` modules, such as :mod:`torchaudio.models` and :mod:`torchaudio.transforms`, or even third-party libraries like `SentencePiece <https://github.com/google/sentencepiece>`__ and `DeepPhonemizer <https://github.com/as-ideas/DeepPhonemizer>`__. But this implementation detail is abstracted away from library users.
.. _RNNT:
RNN-T Streaming/Non-Streaming ASR
---------------------------------
Interface
^^^^^^^^^
~~~~~~~~~
``RNNTBundle`` defines ASR pipelines and consists of three steps: feature extraction, inference, and de-tokenization.
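For illustration, the following is a minimal non-streaming sketch; the file name
``speech.wav`` is a placeholder, the audio is assumed to be mono and already at
``bundle.sample_rate``, and the beam width of 10 is arbitrary.

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.EMFORMER_RNNT_BASE_LIBRISPEECH

    # The three steps of the pipeline.
    feature_extractor = bundle.get_feature_extractor()
    decoder = bundle.get_decoder()
    token_processor = bundle.get_token_processor()

    waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder input

    with torch.inference_mode():
        features, length = feature_extractor(waveform.squeeze())
        hypotheses = decoder(features, length, 10)  # beam width of 10
    transcript = token_processor(hypotheses[0][0])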
......@@ -45,7 +47,7 @@ Interface
.. minigallery:: torchaudio.pipelines.RNNTBundle
Pretrained Models
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
......@@ -55,11 +57,11 @@ Pretrained Models
EMFORMER_RNNT_BASE_LIBRISPEECH
wav2vec 2.0 / HuBERT - SSL
--------------------------
wav2vec 2.0 / HuBERT / WavLM - SSL
----------------------------------
Interface
^^^^^^^^^
~~~~~~~~~
``Wav2Vec2Bundle`` instantiates models that generate acoustic features that can be used for downstream inference and fine-tuning.
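As a rough sketch (the bundle choice and the file name below are illustrative only),
feature extraction for downstream tasks looks like this:

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.WAV2VEC2_BASE
    model = bundle.get_model()

    waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder input
    # Resample if the recording does not match the rate the bundle expects.
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    with torch.inference_mode():
        # Returns one feature tensor per transformer layer, plus valid lengths.
        features, lengths = model.extract_features(waveform)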
......@@ -73,7 +75,7 @@ Interface
Wav2Vec2Bundle
Pretrained Models
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
......@@ -84,15 +86,21 @@ Pretrained Models
WAV2VEC2_LARGE
WAV2VEC2_LARGE_LV60K
WAV2VEC2_XLSR53
WAV2VEC2_XLSR_300M
WAV2VEC2_XLSR_1B
WAV2VEC2_XLSR_2B
HUBERT_BASE
HUBERT_LARGE
HUBERT_XLARGE
WAVLM_BASE
WAVLM_BASE_PLUS
WAVLM_LARGE
wav2vec 2.0 / HuBERT - Fine-tuned ASR
-------------------------------------
Interface
^^^^^^^^^
~~~~~~~~~
``Wav2Vec2ASRBundle`` instantiates models that generate a probability distribution over pre-defined labels and can be used for ASR.
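The sketch below illustrates the idea with greedy CTC decoding; the bundle choice and
the file name are placeholders, and the input is assumed to already match
``bundle.sample_rate``. Real applications typically use a proper CTC decoder.

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
    model = bundle.get_model()
    labels = bundle.get_labels()  # index 0 is the CTC blank token

    waveform, _ = torchaudio.load("speech.wav")  # placeholder input

    with torch.inference_mode():
        emission, _ = model(waveform)  # (batch, frame, num_labels)

    # Greedy decoding: best label per frame, merge repeats, drop blanks,
    # and turn the word separator "|" into spaces.
    indices = torch.unique_consecutive(torch.argmax(emission[0], dim=-1))
    transcript = "".join(labels[i] for i in indices if i != 0).replace("|", " ").strip()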
......@@ -110,7 +118,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Wav2Vec2ASRBundle
Pretrained Models
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
......@@ -134,7 +142,41 @@ Pretrained Models
HUBERT_ASR_LARGE
HUBERT_ASR_XLARGE
wav2vec 2.0 / HuBERT - Forced Alignment
---------------------------------------
Interface
~~~~~~~~~
``Wav2Vec2FABundle`` bundles a pre-trained model and its associated dictionary. Additionally, it supports appending a ``star`` token dimension.
.. image:: https://download.pytorch.org/torchaudio/doc-assets/pipelines-wav2vec2fabundle.png
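A rough sketch of the intended flow (the waveform file and transcript are placeholders;
see the linked tutorials for the exact argument and return types):

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.MMS_FA
    model = bundle.get_model()
    tokenizer = bundle.get_tokenizer()
    aligner = bundle.get_aligner()

    waveform, _ = torchaudio.load("speech.wav")  # assumed to match bundle.sample_rate
    transcript = ["i", "had", "that", "curiosity", "beside", "me"]  # placeholder words

    with torch.inference_mode():
        emission, _ = model(waveform)
        # One list of token spans (frame-level alignments) per word.
        token_spans = aligner(emission[0], tokenizer(transcript))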
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
Wav2Vec2FABundle
Wav2Vec2FABundle.Tokenizer
Wav2Vec2FABundle.Aligner
.. rubric:: Tutorials using ``Wav2Vec2FABundle``
.. minigallery:: torchaudio.pipelines.Wav2Vec2FABundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
MMS_FA
.. _Tacotron2:
Tacotron2 Text-To-Speech
------------------------
......@@ -147,7 +189,7 @@ Tacotron2 Text-To-Speech
Similarly, ``Vocoder`` can be an algorithm without learnable parameters, like `Griffin-Lim`, or a neural-network-based model, like `Waveglow`.
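A minimal sketch of chaining the three components; the bundle choice
(``TACOTRON2_WAVERNN_CHAR_LJSPEECH``) and the input text are illustrative only.

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
    processor = bundle.get_text_processor()
    tacotron2 = bundle.get_tacotron2()
    vocoder = bundle.get_vocoder()

    tokens, lengths = processor("Hello world, text to speech!")
    with torch.inference_mode():
        spec, spec_lengths, _ = tacotron2.infer(tokens, lengths)
        waveforms, _ = vocoder(spec, spec_lengths)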
Interface
^^^^^^^^^
~~~~~~~~~
.. autosummary::
:toctree: generated
......@@ -163,7 +205,7 @@ Interface
.. minigallery:: torchaudio.pipelines.Tacotron2TTSBundle
Pretrained Models
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
......@@ -179,7 +221,7 @@ Source Separation
-----------------
Interface
^^^^^^^^^
~~~~~~~~~
``SourceSeparationBundle`` instantiates source separation models which take single-channel audio and generate multi-channel audio.
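A minimal sketch, assuming a placeholder mixture file that is resampled to the rate
the bundle expects:

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.CONVTASNET_BASE_LIBRI2MIX
    model = bundle.get_model()

    mixture, sample_rate = torchaudio.load("mixture.wav")  # placeholder input
    mixture = torchaudio.functional.resample(mixture, sample_rate, bundle.sample_rate)

    with torch.inference_mode():
        # Input shape: (batch, 1, time); output shape: (batch, num_sources, time).
        sources = model(mixture.unsqueeze(0))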
......@@ -197,7 +239,7 @@ Interface
.. minigallery:: torchaudio.pipelines.SourceSeparationBundle
Pretrained Models
^^^^^^^^^^^^^^^^^
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
......@@ -207,3 +249,53 @@ Pretrained Models
CONVTASNET_BASE_LIBRI2MIX
HDEMUCS_HIGH_MUSDB_PLUS
HDEMUCS_HIGH_MUSDB
Squim Objective
---------------
Interface
~~~~~~~~~
:py:class:`SquimObjectiveBundle` defines a speech quality and intelligibility measurement (SQUIM) pipeline that can predict **objective** metric scores given the input waveform.
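A minimal sketch, assuming a placeholder input file; the model returns STOI, wideband
PESQ, and SI-SDR estimates in that order.

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.SQUIM_OBJECTIVE
    model = bundle.get_model()

    waveform, sample_rate = torchaudio.load("speech.wav")  # placeholder input
    waveform = torchaudio.functional.resample(waveform, sample_rate, bundle.sample_rate)

    with torch.inference_mode():
        stoi, pesq, si_sdr = model(waveform)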
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimObjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_OBJECTIVE
Squim Subjective
----------------
Interface
~~~~~~~~~
:py:class:`SquimSubjectiveBundle` defines a speech quality and intelligibility measurement (SQUIM) pipeline that can predict **subjective** metric scores given the input waveform.
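A minimal sketch, assuming placeholder files for the test signal and for a clean,
non-matching reference recording:

.. code-block:: python

    import torch
    import torchaudio

    bundle = torchaudio.pipelines.SQUIM_SUBJECTIVE
    model = bundle.get_model()

    waveform, sr = torchaudio.load("degraded.wav")             # placeholder inputs
    reference, ref_sr = torchaudio.load("clean_reference.wav")
    waveform = torchaudio.functional.resample(waveform, sr, bundle.sample_rate)
    reference = torchaudio.functional.resample(reference, ref_sr, bundle.sample_rate)

    with torch.inference_mode():
        mos = model(waveform, reference)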
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
SquimSubjectiveBundle
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
SQUIM_SUBJECTIVE
@misc{zeyer2021does,
title={Why does CTC result in peaky behavior?},
author={Albert Zeyer and Ralf Schlüter and Hermann Ney},
year={2021},
eprint={2105.14849},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@article{wavernn,
author = {Nal Kalchbrenner and
Erich Elsen and
......@@ -66,7 +74,7 @@
year = {2017}
}
@misc{conneau2020unsupervised,
title={Unsupervised Cross-lingual Representation Learning for Speech Recognition},
title={Unsupervised Cross-lingual Representation Learning for Speech Recognition},
author={Alexis Conneau and Alexei Baevski and Ronan Collobert and Abdelrahman Mohamed and Michael Auli},
year={2020},
eprint={2006.13979},
......@@ -80,7 +88,7 @@
year={2014}
}
@misc{ardila2020common,
title={Common Voice: A Massively-Multilingual Speech Corpus},
title={Common Voice: A Massively-Multilingual Speech Corpus},
author={Rosana Ardila and Megan Branson and Kelly Davis and Michael Henretty and Michael Kohler and Josh Meyer and Reuben Morais and Lindsay Saunders and Francis M. Tyers and Gregor Weber},
year={2020},
eprint={1912.06670},
......@@ -99,16 +107,16 @@
}
@INPROCEEDINGS{librilight,
author={J. {Kahn} and M. {Rivière} and W. {Zheng} and E. {Kharitonov} and Q. {Xu} and P. E. {Mazaré} and J. {Karadayi} and V. {Liptchinsky} and R. {Collobert} and C. {Fuegen} and T. {Likhomanenko} and G. {Synnaeve} and A. {Joulin} and A. {Mohamed} and E. {Dupoux}},
booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Libri-Light: A Benchmark for ASR with Limited or No Supervision},
booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Libri-Light: A Benchmark for ASR with Limited or No Supervision},
year={2020},
pages={7669-7673},
note = {\url{https://github.com/facebookresearch/libri-light}},
}
@INPROCEEDINGS{7178964,
author={Panayotov, Vassil and Chen, Guoguo and Povey, Daniel and Khudanpur, Sanjeev},
booktitle={2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Librispeech: An ASR corpus based on public domain audio books},
booktitle={2015 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Librispeech: An ASR corpus based on public domain audio books},
year={2015},
volume={},
number={},
......@@ -122,7 +130,7 @@
year = {2019},
}
@misc{baevski2020wav2vec,
title={wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations},
title={wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations},
author={Alexei Baevski and Henry Zhou and Abdelrahman Mohamed and Michael Auli},
year={2020},
eprint={2006.11477},
......@@ -130,7 +138,7 @@
primaryClass={cs.CL}
}
@misc{hsu2021hubert,
title={HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units},
title={HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units},
author={Wei-Ning Hsu and Benjamin Bolte and Yao-Hung Hubert Tsai and Kushal Lakhotia and Ruslan Salakhutdinov and Abdelrahman Mohamed},
year={2021},
eprint={2106.07447},
......@@ -138,7 +146,7 @@
primaryClass={cs.CL}
}
@misc{hannun2014deep,
title={Deep Speech: Scaling up end-to-end speech recognition},
title={Deep Speech: Scaling up end-to-end speech recognition},
author={Awni Hannun and Carl Case and Jared Casper and Bryan Catanzaro and Greg Diamos and Erich Elsen and Ryan Prenger and Sanjeev Satheesh and Shubho Sengupta and Adam Coates and Andrew Y. Ng},
year={2014},
eprint={1412.5567},
......@@ -146,7 +154,7 @@
primaryClass={cs.CL}
}
@misc{graves2012sequence,
title={Sequence Transduction with Recurrent Neural Networks},
title={Sequence Transduction with Recurrent Neural Networks},
author={Alex Graves},
year={2012},
eprint={1211.3711},
......@@ -154,7 +162,7 @@
primaryClass={cs.NE}
}
@misc{collobert2016wav2letter,
title={Wav2Letter: an End-to-End ConvNet-based Speech Recognition System},
title={Wav2Letter: an End-to-End ConvNet-based Speech Recognition System},
author={Ronan Collobert and Christian Puhrsch and Gabriel Synnaeve},
year={2016},
eprint={1609.03193},
......@@ -162,7 +170,7 @@
primaryClass={cs.LG}
}
@misc{kalchbrenner2018efficient,
title={Efficient Neural Audio Synthesis},
title={Efficient Neural Audio Synthesis},
author={Nal Kalchbrenner and Erich Elsen and Karen Simonyan and Seb Noury and Norman Casagrande and Edward Lockhart and Florian Stimberg and Aaron van den Oord and Sander Dieleman and Koray Kavukcuoglu},
year={2018},
eprint={1802.08435},
......@@ -202,8 +210,8 @@
}
@INPROCEEDINGS{6701851,
author={Perraudin, Nathanaël and Balazs, Peter and Søndergaard, Peter L.},
booktitle={2013 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics},
title={A fast Griffin-Lim algorithm},
booktitle={2013 IEEE Workshop on Applications of Signal Processing to Audio and Acoustics},
title={A fast Griffin-Lim algorithm},
year={2013},
volume={},
number={},
......@@ -211,8 +219,8 @@
doi={10.1109/WASPAA.2013.6701851}}
@INPROCEEDINGS{1172092,
author={Griffin, D. and Jae Lim},
booktitle={ICASSP '83. IEEE International Conference on Acoustics, Speech, and Signal Processing},
title={Signal estimation from modified short-time Fourier transform},
booktitle={ICASSP '83. IEEE International Conference on Acoustics, Speech, and Signal Processing},
title={Signal estimation from modified short-time Fourier transform},
year={1983},
volume={8},
number={},
......@@ -220,8 +228,8 @@
doi={10.1109/ICASSP.1983.1172092}}
@INPROCEEDINGS{6854049,
author={Ghahremani, Pegah and BabaAli, Bagher and Povey, Daniel and Riedhammer, Korbinian and Trmal, Jan and Khudanpur, Sanjeev},
booktitle={2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={A pitch extraction algorithm tuned for automatic speech recognition},
booktitle={2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={A pitch extraction algorithm tuned for automatic speech recognition},
year={2014},
volume={},
number={},
......@@ -254,16 +262,16 @@
organization={IEEE}
}
@inproceedings{shi2021emformer,
title={Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition},
title={Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition},
author={Shi, Yangyang and Wang, Yongqiang and Wu, Chunyang and Yeh, Ching-Feng and Chan, Julian and Zhang, Frank and Le, Duc and Seltzer, Mike},
booktitle={ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
booktitle={ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={6783-6787},
year={2021}
}
@inproceedings{9747706,
author={Shi, Yangyang and Wu, Chunyang and Wang, Dilin and Xiao, Alex and Mahadeokar, Jay and Zhang, Xiaohui and Liu, Chunxi and Li, Ke and Shangguan, Yuan and Nagaraja, Varun and Kalinli, Ozlem and Seltzer, Mike},
booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution},
booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Streaming Transformer Transducer based Speech Recognition Using Non-Causal Convolution},
year={2022},
volume={},
number={},
......@@ -439,3 +447,154 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
journal={arXiv preprint arXiv:1805.10190},
year={2018}
}
@INPROCEEDINGS{9746490,
author={Srivastava, Sangeeta and Wang, Yun and Tjandra, Andros and Kumar, Anurag and Liu, Chunxi and Singh, Kritika and Saraf, Yatharth},
booktitle={ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
title={Conformer-Based Self-Supervised Learning For Non-Speech Audio Tasks},
year={2022},
volume={},
number={},
pages={8862-8866},
doi={10.1109/ICASSP43922.2022.9746490}}
@article{chen2022wavlm,
title={Wavlm: Large-scale self-supervised pre-training for full stack speech processing},
author={Chen, Sanyuan and Wang, Chengyi and Chen, Zhengyang and Wu, Yu and Liu, Shujie and Chen, Zhuo and Li, Jinyu and Kanda, Naoyuki and Yoshioka, Takuya and Xiao, Xiong and others},
journal={IEEE Journal of Selected Topics in Signal Processing},
volume={16},
number={6},
pages={1505--1518},
year={2022},
publisher={IEEE}
}
@inproceedings{GigaSpeech2021,
title={GigaSpeech: An Evolving, Multi-domain ASR Corpus with 10,000 Hours of Transcribed Audio},
booktitle={Proc. Interspeech 2021},
year=2021,
author={Guoguo Chen and Shuzhou Chai and Guanbo Wang and Jiayu Du and Wei-Qiang Zhang and Chao Weng and Dan Su and Daniel Povey and Jan Trmal and Junbo Zhang and Mingjie Jin and Sanjeev Khudanpur and Shinji Watanabe and Shuaijiang Zhao and Wei Zou and Xiangang Li and Xuchen Yao and Yongqing Wang and Yujun Wang and Zhao You and Zhiyong Yan}
}
@inproceedings{NEURIPS2020_c5d73680,
author = {Kong, Jungil and Kim, Jaehyeon and Bae, Jaekyoung},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
pages = {17022--17033},
publisher = {Curran Associates, Inc.},
title = {HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis},
url = {https://proceedings.neurips.cc/paper/2020/file/c5d736809766d46260d816d8dbc9eb44-Paper.pdf},
volume = {33},
year = {2020}
}
@inproceedings{ko15_interspeech,
author={Tom Ko and Vijayaditya Peddinti and Daniel Povey and Sanjeev Khudanpur},
title={{Audio augmentation for speech recognition}},
year=2015,
booktitle={Proc. Interspeech 2015},
pages={3586--3589},
doi={10.21437/Interspeech.2015-711}
}
@misc{musan2015,
author = {David Snyder and Guoguo Chen and Daniel Povey},
title = {{MUSAN}: {A} {M}usic, {S}peech, and {N}oise {C}orpus},
year = {2015},
eprint = {1510.08484},
note = {arXiv:1510.08484v1}
}
@article{babu2021xls,
title={XLS-R: Self-supervised cross-lingual speech representation learning at scale},
author={Babu, Arun and Wang, Changhan and Tjandra, Andros and Lakhotia, Kushal and Xu, Qiantong and Goyal, Naman and Singh, Kritika and von Platen, Patrick and Saraf, Yatharth and Pino, Juan and others},
journal={arXiv preprint arXiv:2111.09296},
year={2021}
}
@inproceedings{valk2021voxlingua107,
title={VoxLingua107: a dataset for spoken language recognition},
author={Valk, J{\"o}rgen and Alum{\"a}e, Tanel},
booktitle={2021 IEEE Spoken Language Technology Workshop (SLT)},
pages={652--658},
year={2021},
organization={IEEE}
}
@inproceedings{scheibler2018pyroomacoustics,
title={Pyroomacoustics: A python package for audio room simulation and array processing algorithms},
author={Scheibler, Robin and Bezzam, Eric and Dokmani{\'c}, Ivan},
booktitle={2018 IEEE international conference on acoustics, speech and signal processing (ICASSP)},
pages={351--355},
year={2018},
organization={IEEE}
}
@article{allen1979image,
title={Image method for efficiently simulating small-room acoustics},
author={Allen, Jont B and Berkley, David A},
journal={The Journal of the Acoustical Society of America},
volume={65},
number={4},
pages={943--950},
year={1979},
publisher={Acoustical Society of America}
}
@misc{wiki:Absorption_(acoustics),
author = "{Wikipedia contributors}",
title = "Absorption (acoustics) --- {W}ikipedia{,} The Free Encyclopedia",
url = "https://en.wikipedia.org/wiki/Absorption_(acoustics)",
note = "[Online]"
}
@article{reddy2020interspeech,
title={The interspeech 2020 deep noise suppression challenge: Datasets, subjective testing framework, and challenge results},
author={Reddy, Chandan KA and Gopal, Vishak and Cutler, Ross and Beyrami, Ebrahim and Cheng, Roger and Dubey, Harishchandra and Matusevych, Sergiy and Aichner, Robert and Aazami, Ashkan and Braun, Sebastian and others},
journal={arXiv preprint arXiv:2005.13981},
year={2020}
}
@article{manocha2022speech,
title={Speech quality assessment through MOS using non-matching references},
author={Manocha, Pranay and Kumar, Anurag},
journal={arXiv preprint arXiv:2206.12285},
year={2022}
}
@article{cooper2021voices,
title={How do voices from past speech synthesis challenges compare today?},
author={Cooper, Erica and Yamagishi, Junichi},
journal={arXiv preprint arXiv:2105.02373},
year={2021}
}
@article{mysore2014can,
title={Can we automatically transform speech recorded on common consumer devices in real-world environments into professional production quality speech?—a dataset, insights, and challenges},
author={Mysore, Gautham J},
journal={IEEE Signal Processing Letters},
volume={22},
number={8},
pages={1006--1010},
year={2014},
publisher={IEEE}
}
@article{kumar2023torchaudio,
title={TorchAudio-Squim: Reference-less Speech Quality and Intelligibility measures in TorchAudio},
author={Kumar, Anurag and Tan, Ke and Ni, Zhaoheng and Manocha, Pranay and Zhang, Xiaohui and Henderson, Ethan and Xu, Buye},
journal={arXiv preprint arXiv:2304.01448},
year={2023}
}
@incollection{45611,
title = {CNN Architectures for Large-Scale Audio Classification},
author = {Shawn Hershey and Sourish Chaudhuri and Daniel P. W. Ellis and Jort F. Gemmeke and Aren Jansen and Channing Moore and Manoj Plakal and Devin Platt and Rif A. Saurous and Bryan Seybold and Malcolm Slaney and Ron Weiss and Kevin Wilson},
year = {2017},
URL = {https://arxiv.org/abs/1609.09430},
booktitle = {International Conference on Acoustics, Speech and Signal Processing (ICASSP)}
}
@misc{pratap2023scaling,
title={Scaling Speech Technology to 1,000+ Languages},
author={Vineel Pratap and Andros Tjandra and Bowen Shi and Paden Tomasello and Arun Babu and Sayani Kundu and Ali Elkahky and Zhaoheng Ni and Apoorv Vyas and Maryam Fazel-Zarandi and Alexei Baevski and Yossi Adi and Xiaohui Zhang and Wei-Ning Hsu and Alexis Conneau and Michael Auli},
year={2023},
eprint={2305.13516},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
@article{dowson1982frechet,
title={The Fr{\'e}chet distance between multivariate normal distributions},
author={Dowson, DC and Landau, BV666017},
journal={Journal of multivariate analysis},
volume={12},
number={3},
pages={450--455},
year={1982},
publisher={Elsevier}
}
......@@ -26,6 +26,4 @@ Utilities
:toctree: generated
:nosignatures:
init_sox_effects
shutdown_sox_effects
effect_names
torchaudio
==========
I/O functionalities
~~~~~~~~~~~~~~~~~~~
.. currentmodule:: torchaudio
Audio I/O functions are implemented in :ref:`torchaudio.backend<backend>` module, but for the ease of use, the following functions are made available on :mod:`torchaudio` module. There are different backends available and you can switch backends with :func:`set_audio_backend`.
I/O
---
``torchaudio`` top-level module provides the following functions that make
it easy to handle audio data.
Please refer to :ref:`backend` for the detail, and the :doc:`Audio I/O tutorial <../tutorials/audio_io_tutorial>` for the usage.
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/io.rst
.. function:: torchaudio.info(filepath: str, ...)
info
load
save
Fetch meta data of an audio file. Refer to :ref:`backend` for the detail.
.. _backend:
.. function:: torchaudio.load(filepath: str, ...)
Backend and Dispatcher
----------------------
Load audio file into torch.Tensor object. Refer to :ref:`backend` for the detail.
Decoding and encoding media is a highly elaborate process. Therefore, TorchAudio
relies on third-party libraries to perform these operations. These third-party
libraries are called ``backends``, and TorchAudio currently integrates the
following libraries.
.. function:: torchaudio.save(filepath: str, src: torch.Tensor, sample_rate: int, ...)
Please refer to `Installation <./installation.html>`__ for how to enable backends.
Save torch.Tensor object into an audio format. Refer to :ref:`backend` for the detail.
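A minimal sketch of the three functions, with ``speech.wav`` as a placeholder path:

.. code-block:: python

    import torchaudio

    metadata = torchaudio.info("speech.wav")                 # sample rate, num frames, etc.
    waveform, sample_rate = torchaudio.load("speech.wav")    # decode into a Tensor
    torchaudio.save("speech_copy.wav", waveform, sample_rate)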
Conventionally, TorchAudio has had its I/O backend set globally at runtime
based on availability. However, this approach does not allow applications to
use different backends, and it is not well-suited for large codebases.
.. currentmodule:: torchaudio
For these reasons, in v2.0, we introduced a dispatcher, a new mechanism to allow
users to choose a backend for each function call.
When dispatcher mode is enabled, all the I/O functions accept an extra keyword argument
``backend``, which specifies the desired backend. If the specified
backend is not available, the function call will fail.
If a backend is not explicitly chosen, the functions will select a backend to use based on the order of precedence and library availability; see the sketch after the table below.
The following table summarizes the backends.
.. list-table::
:header-rows: 1
:widths: 8 12 25 60
* - Priority
- Backend
- Supported OS
- Note
* - 1
- FFmpeg
- Linux, macOS, Windows
- Use :py:func:`~torchaudio.utils.ffmpeg_utils.get_audio_decoders` and
:py:func:`~torchaudio.utils.ffmpeg_utils.get_audio_encoders`
to retrieve the supported codecs.
This backend supports various protocols, such as HTTPS and MP4, and file-like objects.
* - 2
- SoX
- Linux, macOS
- Use :py:func:`~torchaudio.utils.sox_utils.list_read_formats` and
:py:func:`~torchaudio.utils.sox_utils.list_write_formats`
to retrieve the supported codecs.
This backend does *not* support file-like objects.
* - 3
- SoundFile
- Linux, macOS, Windows
- Please refer to `the official document <https://pysoundfile.readthedocs.io/>`__ for the supported codecs.
This backend supports file-like objects.
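For example, the following sketch (placeholder file name; the requested backend must be
installed) passes the ``backend`` argument explicitly, while leaving other calls to the
dispatcher's default selection:

.. code-block:: python

    import torchaudio

    # Explicitly request the SoundFile backend for this call only.
    waveform, sample_rate = torchaudio.load("speech.wav", backend="soundfile")

    # Without ``backend``, the dispatcher picks one by priority and availability.
    metadata = torchaudio.info("speech.wav")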
.. _dispatcher_migration:
Dispatcher Migration
~~~~~~~~~~~~~~~~~~~~
We are migrating the I/O functions to use the dispatcher mechanism. This
incurs multiple changes, some of which are backward-incompatible and
require users to change their function calls.
The (planned) changes are as follows. For up-to-date information,
please refer to https://github.com/pytorch/audio/issues/2950
* In 2.0, audio I/O backend dispatcher was introduced.
Users can opt in to using the dispatcher by setting the environment variable
``TORCHAUDIO_USE_BACKEND_DISPATCHER=1``.
* In 2.1, the dispatcher becomes the default mechanism for I/O.
Those who need to keep using the previous mechanism (global backend) can do
so by setting ``TORCHAUDIO_USE_BACKEND_DISPATCHER=0``.
* In 2.2, the legacy global backend mechanism will be removed.
Utility functions :py:func:`get_audio_backend` and :py:func:`set_audio_backend`
become no-op.
Furthermore, we are removing file-like object support from the libsox backend, as it
is better supported by the FFmpeg backend, and its removal makes the build process simpler.
Therefore, beginning with 2.1, FFmpeg and SoundFile are the sole backends that support
file-like objects.
Backend Utilities
~~~~~~~~~~~~~~~~~
-----------------
The following functions are effective only when backend dispatcher is disabled.
.. autofunction:: list_audio_backends
Note that the changes in 2.1 mark :py:func:`get_audio_backend` and
:py:func:`set_audio_backend` as deprecated.
.. autofunction:: get_audio_backend
.. autosummary::
:toctree: generated
:nosignatures:
.. autofunction:: set_audio_backend
list_audio_backends
get_audio_backend
set_audio_backend
......@@ -83,14 +83,19 @@ Utility
:nosignatures:
AmplitudeToDB
MelScale
InverseMelScale
MuLawEncoding
MuLawDecoding
Resample
Fade
Vol
Loudness
AddNoise
Convolve
FFTConvolve
Speed
SpeedPerturbation
Deemphasis
Preemphasis
Feature Extractions
-------------------
......@@ -101,6 +106,8 @@ Feature Extractions
Spectrogram
InverseSpectrogram
MelScale
InverseMelScale
MelSpectrogram
GriffinLim
MFCC
......
......@@ -15,7 +15,7 @@ This directory contains sample implementations of training and evaluation pipeli
### Pipeline Demo
[`pipeline_demo.py`](./pipeline_demo.py) demonstrates how to use the `EMFORMER_RNNT_BASE_LIBRISPEECH`
or `EMFORMER_RNNT_BASE_TEDLIUM3` bundle, which wraps a pre-trained Emformer RNN-T produced by the corresponding recipe below, to perform streaming and full-context ASR on several audio samples.
## Model Types
......@@ -67,6 +67,8 @@ The table below contains WER results for dev and test subsets of TED-LIUM releas
| dev | 0.108 |
| test | 0.098 |
[`tedlium3/eval_pipeline.py`](./tedlium3/eval_pipeline.py) evaluates the pre-trained `EMFORMER_RNNT_BASE_TEDLIUM3` bundle on the dev and test sets of TED-LIUM release 3. Running the script should produce WER results that are identical to those in the above table.
### MuST-C release v2.0
The MuST-C model is configured with a vocabulary size of 500. Consequently, the MuST-C model's last linear layer in the joiner has an output dimension of 501 (500 + 1 to account for the blank symbol). In contrast to the transcripts of the datasets used for the above two models, MuST-C's transcripts are cased and punctuated; we preserve the casing and punctuation when training the SentencePiece model.
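To make the bookkeeping concrete, the sketch below constructs the base Emformer RNN-T with the matching output size; it mirrors the call used by the Lightning modules in this recipe.

```python
from torchaudio.models import emformer_rnnt_base

# 500 SentencePiece tokens + 1 blank symbol -> 501 output symbols in the joiner.
model = emformer_rnnt_base(num_symbols=501)
```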
......
......@@ -19,7 +19,7 @@ from common import (
piecewise_linear_log,
spectrogram_transform,
)
from must.dataset import MUSTC
from mustc.dataset import MUSTC
logger = logging.getLogger()
......
......@@ -13,8 +13,10 @@ from typing import Callable
import torch
import torchaudio
from common import MODEL_TYPE_LIBRISPEECH
from common import MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_MUSTC, MODEL_TYPE_TEDLIUM3
from mustc.dataset import MUSTC
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
logger = logging.getLogger(__name__)
......@@ -30,6 +32,14 @@ _CONFIGS = {
partial(torchaudio.datasets.LIBRISPEECH, url="test-clean"),
EMFORMER_RNNT_BASE_LIBRISPEECH,
),
MODEL_TYPE_MUSTC: Config(
partial(MUSTC, subset="tst-COMMON"),
EMFORMER_RNNT_BASE_MUSTC,
),
MODEL_TYPE_TEDLIUM3: Config(
partial(torchaudio.datasets.TEDLIUM, release="release3", subset="test"),
EMFORMER_RNNT_BASE_TEDLIUM3,
),
}
......@@ -55,9 +65,9 @@ def run_eval_streaming(args):
with torch.no_grad():
features, length = streaming_feature_extractor(segment)
hypos, state = decoder.infer(features, length, 10, state=state, hypothesis=hypothesis)
hypothesis = hypos[0]
transcript = token_processor(hypothesis[0], lstrip=False)
print(transcript, end="", flush=True)
hypothesis = hypos
transcript = token_processor(hypos[0][0], lstrip=True)
print(transcript, end="\r", flush=True)
print()
# Non-streaming decode.
......
#!/usr/bin/env python3
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import torch
import torchaudio
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_TEDLIUM3
logger = logging.getLogger(__name__)
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def _eval_subset(tedlium_path, subset, feature_extractor, decoder, token_processor, use_cuda):
total_edit_distance = 0
total_length = 0
if subset == "dev":
dataset = torchaudio.datasets.TEDLIUM(tedlium_path, release="release3", subset="dev")
elif subset == "test":
dataset = torchaudio.datasets.TEDLIUM(tedlium_path, release="release3", subset="test")
with torch.no_grad():
for idx in range(len(dataset)):
sample = dataset[idx]
waveform = sample[0].squeeze()
if use_cuda:
waveform = waveform.to(device="cuda")
actual = sample[2].replace("\n", "")
if actual == "ignore_time_segment_in_scoring":
continue
features, length = feature_extractor(waveform)
hypos = decoder(features, length, 20)
hypothesis = hypos[0]
hypothesis = token_processor(hypothesis[0])
total_edit_distance += compute_word_level_distance(actual, hypothesis)
total_length += len(actual.split())
if idx % 100 == 0:
print(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
print(f"Final WER for {subset} set: {total_edit_distance / total_length}")
def run_eval_pipeline(args):
decoder = EMFORMER_RNNT_BASE_TEDLIUM3.get_decoder()
token_processor = EMFORMER_RNNT_BASE_TEDLIUM3.get_token_processor()
feature_extractor = EMFORMER_RNNT_BASE_TEDLIUM3.get_feature_extractor()
if args.use_cuda:
feature_extractor = feature_extractor.to(device="cuda").eval()
decoder = decoder.to(device="cuda")
_eval_subset(args.tedlium_path, "dev", feature_extractor, decoder, token_processor, args.use_cuda)
_eval_subset(args.tedlium_path, "test", feature_extractor, decoder, token_processor, args.use_cuda)
def _parse_args():
parser = ArgumentParser(
description=__doc__,
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--tedlium-path",
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
run_eval_pipeline(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
14.762723922729492,
16.020633697509766,
16.911531448364258,
16.80994415283203,
18.72406005859375,
18.84550666809082,
19.021404266357422,
19.623443603515625,
19.403806686401367,
19.52766990661621,
19.253433227539062,
19.211227416992188,
19.216045379638672,
19.315574645996094,
19.267532348632812,
19.146976470947266,
18.98181915283203,
18.81462287902832,
18.67916488647461,
18.5198917388916,
18.360441207885742,
18.18699836730957,
18.008447647094727,
17.82094955444336,
17.644861221313477,
17.51972007751465,
17.51348876953125,
17.171707153320312,
17.070415496826172,
17.21990394592285,
16.868940353393555,
17.048307418823242,
16.894960403442383,
17.04732322692871,
16.955705642700195,
17.053966522216797,
17.037548065185547,
17.03425407409668,
17.03618621826172,
16.979724884033203,
16.889690399169922,
16.779285430908203,
16.689767837524414,
16.62590789794922,
16.600360870361328,
16.610321044921875,
16.692338943481445,
16.61323356628418,
16.638328552246094,
16.494739532470703,
16.42980194091797,
16.23759651184082,
16.144210815429688,
16.018585205078125,
15.985218048095703,
15.947102546691895,
15.894798278808594,
15.832999229431152,
15.704426765441895,
15.538087844848633,
15.378302574157715,
15.19461441040039,
15.00456714630127,
14.861663818359375,
14.676336288452148,
14.594626426696777,
14.561753273010254,
14.464197158813477,
14.43082046508789,
14.388801574707031,
14.257562637329102,
14.231459617614746,
14.19768238067627,
14.123900413513184,
14.159867286682129,
14.059795379638672,
13.968880653381348,
13.927794456481934,
13.645783424377441,
12.086114883422852
],
"invstddev": [
0.3553205132484436,
0.3363242745399475,
0.3194723129272461,
0.3199574947357178,
0.28755369782447815,
0.2879481613636017,
0.27939942479133606,
0.27543479204177856,
0.2806696891784668,
0.28141146898269653,
0.2753477990627289,
0.274241179227829,
0.27815768122673035,
0.27794352173805237,
0.2763032615184784,
0.2744459807872772,
0.27375343441963196,
0.27415215969085693,
0.27628427743911743,
0.27667510509490967,
0.2806207835674286,
0.28371962904930115,
0.2893684506416321,
0.2944427728652954,
0.2989389896392822,
0.30326008796691895,
0.30760079622268677,
0.3089521527290344,
0.3105863034725189,
0.31274259090423584,
0.31318506598472595,
0.3154853880405426,
0.3167822062969208,
0.3182784914970398,
0.31875282526016235,
0.3185810148715973,
0.31908345222473145,
0.3207632303237915,
0.32282087206840515,
0.3241617977619171,
0.3260948061943054,
0.32735878229141235,
0.32947203516960144,
0.33052706718444824,
0.3309975266456604,
0.3301711678504944,
0.32793518900871277,
0.3252142369747162,
0.32336947321891785,
0.32320502400398254,
0.3264254927635193,
0.32860180735588074,
0.3322647213935852,
0.3100382685661316,
0.3216720223426819,
0.32280418276786804,
0.32710719108581543,
0.3284962773323059,
0.3319654166698456,
0.32880258560180664,
0.33075764775276184,
0.32947179675102234,
0.32880640029907227,
0.3296009302139282,
0.324250727891922,
0.3247823715209961,
0.328702837228775,
0.32418182492256165,
0.3247915208339691,
0.3251509964466095,
0.31811773777008057,
0.3195462226867676,
0.3187839686870575,
0.31459841132164,
0.32190003991127014,
0.3193890154361725,
0.315574049949646,
0.317360520362854,
0.3075887858867645,
0.3034747838973999
]
}
import os
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
from common import (
Batch,
batch_by_token_count,
FunctionalModule,
GlobalStatsNormalization,
piecewise_linear_log,
post_process_hypos,
spectrogram_transform,
WarmupLR,
)
from pytorch_lightning import LightningModule
from torchaudio.models import emformer_rnnt_base, RNNTBeamSearch
class CustomDataset(torch.utils.data.Dataset):
r"""Sort TEDLIUM3 samples by target length and batch to max durations."""
def __init__(self, base_dataset, max_token_limit):
super().__init__()
self.base_dataset = base_dataset
idx_target_lengths = [
(idx, self._target_length(fileid, line)) for idx, (fileid, line) in enumerate(self.base_dataset._filelist)
]
idx_target_lengths = [(idx, length) for idx, length in idx_target_lengths if length != -1]
assert len(idx_target_lengths) > 0
idx_target_lengths = sorted(idx_target_lengths, key=lambda x: x[1])
assert max_token_limit >= idx_target_lengths[-1][1]
self.batches = batch_by_token_count(idx_target_lengths, max_token_limit)
def _target_length(self, fileid, line):
transcript_path = os.path.join(self.base_dataset._path, "stm", fileid)
with open(transcript_path + ".stm") as f:
transcript = f.readlines()[line]
_, _, _, start_time, end_time, _, transcript = transcript.split(" ", 6)
if transcript.lower() == "ignore_time_segment_in_scoring\n":
return -1
else:
return float(end_time) - float(start_time)
def __getitem__(self, idx):
return [self.base_dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class EvalDataset(torch.utils.data.IterableDataset):
def __init__(self, base_dataset):
super().__init__()
self.base_dataset = base_dataset
def __iter__(self):
for sample in iter(self.base_dataset):
actual = sample[2].replace("\n", "")
if actual == "ignore_time_segment_in_scoring":
continue
yield sample
class TEDLIUM3RNNTModule(LightningModule):
def __init__(
self,
*,
tedlium_path: str,
sp_model_path: str,
global_stats_path: str,
):
super().__init__()
self.model = emformer_rnnt_base(num_symbols=501)
self.loss = torchaudio.transforms.RNNTLoss(reduction="mean", clamp=1.0)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
self.tedlium_path = tedlium_path
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.blank_idx = self.sp_model.get_piece_size()
def _extract_labels(self, samples: List):
"""Convert text transcript into int labels.
Note:
There are ``<unk>`` tokens in the training set that are regarded as normal tokens
by the SentencePiece model. This will impact RNNT decoding since the decoding result
of ``<unk>`` will be ``?? unk ??`` and will not be excluded from the final prediction.
To address it, here we replace ``<unk>`` with ``<garbage>`` and set
``user_defined_symbols=["<garbage>"]`` in the SentencePiece model training.
Then we map the index of ``<garbage>`` to the real ``unknown`` index.
"""
targets = [
self.sp_model.encode(sample[2].lower().replace("<unk>", "<garbage>").replace("\n", ""))
for sample in samples
]
targets = [
[ele if ele != 4 else self.sp_model.unk_id() for ele in target] for target in targets
] # map id of <unk> token to unk_id
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _train_extract_features(self, samples: List):
mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.train_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _valid_extract_features(self, samples: List):
mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = self.valid_data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
def _train_collate_fn(self, samples: List):
features, feature_lengths = self._train_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _valid_collate_fn(self, samples: List):
features, feature_lengths = self._valid_extract_features(samples)
targets, target_lengths = self._extract_labels(samples)
return Batch(features, feature_lengths, targets, target_lengths)
def _test_collate_fn(self, samples: List):
return self._valid_collate_fn(samples), [sample[2] for sample in samples]
def _step(self, batch, batch_idx, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _ = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[
{"scheduler": self.warmup_lr_scheduler, "interval": "step"},
],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
return self._step(batch, batch_idx, "train")
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch_tuple, batch_idx):
return self._step(batch_tuple[0], batch_idx, "test")
def train_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="train"), 100)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._train_collate_fn,
num_workers=10,
shuffle=True,
)
return dataloader
def val_dataloader(self):
dataset = CustomDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev"), 100)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=None,
collate_fn=self._valid_collate_fn,
num_workers=10,
)
return dataloader
def test_dataloader(self):
dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="test"))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
def dev_dataloader(self):
dataset = EvalDataset(torchaudio.datasets.TEDLIUM(self.tedlium_path, release="release3", subset="dev"))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
return dataloader
#!/usr/bin/env python3
"""Train the SentencePiece model by using the transcripts of TED-LIUM release 3 training set.
Example:
python train_spm.py --tedlium-path /home/datasets/
"""
import io
import logging
import os
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
logger = logging.getLogger(__name__)
def _parse_args():
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--tedlium-path",
required=True,
type=pathlib.Path,
help="Path to TED-LIUM release 3 dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path("./spm_bpe_500.model"),
type=pathlib.Path,
help="File to save model to. (Default: './spm_bpe_500.model')",
)
parser.add_argument("--debug", action="store_true", help="whether to use debug level for logging")
return parser.parse_args()
def _extract_train_text(tedlium_path):
stm_path = tedlium_path / "TEDLIUM_release-3/data/stm/"
transcripts = []
for file in sorted(os.listdir(stm_path)):
if file.endswith(".stm"):
file = os.path.join(stm_path, file)
with open(file) as f:
for line in f.readlines():
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = line.split(" ", 6)
if transcript == "ignore_time_segment_in_scoring\n":
continue
else:
transcript = transcript.replace("<unk>", "<garbage>").replace("\n", "")
transcripts.append(transcript)
return transcripts
def train_spm(input):
model_writer = io.BytesIO()
    spm.SentencePieceTrainer.train(
        sentence_iterator=iter(input),
        # Write the trained model into the in-memory buffer so getvalue() below is non-empty.
        model_writer=model_writer,
vocab_size=500,
model_type="bpe",
input_sentence_size=-1,
character_coverage=1.0,
user_defined_symbols=["<garbage>"],
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
def _init_logger(debug):
fmt = "%(asctime)s %(message)s" if debug else "%(message)s"
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(format=fmt, level=level, datefmt="%Y-%m-%d %H:%M:%S")
def cli_main():
args = _parse_args()
_init_logger(args.debug)
    transcripts = _extract_train_text(args.tedlium_path)
model = train_spm(transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
logger.info("Successfully trained the sentencepiece model")
if __name__ == "__main__":
cli_main()
# Conformer RNN-T ASR Example
This directory contains sample implementations of training and evaluation pipelines for a Conformer RNN-T ASR model.
## Setup
### Install PyTorch and TorchAudio nightly or from source
Because Conformer RNN-T is currently a prototype feature, you will need to either use the TorchAudio nightly build or build TorchAudio from source. Note also that GPU support is required for training.
To install the nightly, follow the directions at <https://pytorch.org/>.
To build TorchAudio from source, refer to the [contributing guidelines](https://github.com/pytorch/audio/blob/main/CONTRIBUTING.md).
### Install additional dependencies
```bash
pip install pytorch-lightning sentencepiece tensorboard
```
## Usage
### Training
[`train.py`](./train.py) trains a Conformer RNN-T model (30.2M parameters, 121MB) on LibriSpeech using PyTorch Lightning. Note that the script expects users to have the following:
- Access to GPU nodes for training.
- Full LibriSpeech dataset.
- SentencePiece model to be used to encode targets; the model can be generated using [`train_spm.py`](./train_spm.py).
- File (`--global-stats-path`) that contains training set feature statistics; this file can be generated using [`global_stats.py`](../emformer_rnnt/global_stats.py).
Sample SLURM command:
```
srun --cpus-per-task=12 --gpus-per-node=8 -N 4 --ntasks-per-node=8 python train.py --exp-dir ./experiments --librispeech-path ./librispeech/ --global-stats-path ./global_stats.json --sp-model-path ./spm_unigram_1023.model --epochs 160
```
### Evaluation
[`eval.py`](./eval.py) evaluates a trained Conformer RNN-T model on LibriSpeech test-clean.
Sample SLURM command:
```
srun python eval.py --checkpoint-path ./experiments/checkpoints/epoch=159.ckpt --librispeech-path ./librispeech/ --sp-model-path ./spm_unigram_1023.model --use-cuda
```
The table below contains WER results for various splits.
| | WER |
|:-------------------:|-------------:|
| test-clean | 0.0310 |
| test-other | 0.0805 |
| dev-clean | 0.0314 |
| dev-other | 0.0827 |
import os
import random
import torch
import torchaudio
from pytorch_lightning import LightningDataModule
def _batch_by_token_count(idx_target_lengths, max_tokens, batch_size=None):
batches = []
current_batch = []
current_token_count = 0
for idx, target_length in idx_target_lengths:
if current_token_count + target_length > max_tokens or (batch_size and len(current_batch) == batch_size):
batches.append(current_batch)
current_batch = [idx]
current_token_count = target_length
else:
current_batch.append(idx)
current_token_count += target_length
if current_batch:
batches.append(current_batch)
return batches
def get_sample_lengths(librispeech_dataset):
fileid_to_target_length = {}
def _target_length(fileid):
if fileid not in fileid_to_target_length:
speaker_id, chapter_id, _ = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + librispeech_dataset._ext_txt
file_text = os.path.join(librispeech_dataset._path, speaker_id, chapter_id, file_text)
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
fileid_to_target_length[fileid_text] = len(transcript)
return fileid_to_target_length[fileid]
return [_target_length(fileid) for fileid in librispeech_dataset._walker]
class CustomBucketDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset,
lengths,
max_tokens,
num_buckets,
shuffle=False,
batch_size=None,
):
super().__init__()
assert len(dataset) == len(lengths)
self.dataset = dataset
max_length = max(lengths)
min_length = min(lengths)
assert max_tokens >= max_length
buckets = torch.linspace(min_length, max_length, num_buckets)
lengths = torch.tensor(lengths)
bucket_assignments = torch.bucketize(lengths, buckets)
idx_length_buckets = [(idx, length, bucket_assignments[idx]) for idx, length in enumerate(lengths)]
if shuffle:
idx_length_buckets = random.sample(idx_length_buckets, len(idx_length_buckets))
else:
idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[1], reverse=True)
sorted_idx_length_buckets = sorted(idx_length_buckets, key=lambda x: x[2])
self.batches = _batch_by_token_count(
[(idx, length) for idx, length, _ in sorted_idx_length_buckets],
max_tokens,
batch_size=batch_size,
)
def __getitem__(self, idx):
return [self.dataset[subidx] for subidx in self.batches[idx]]
def __len__(self):
return len(self.batches)
class TransformDataset(torch.utils.data.Dataset):
def __init__(self, dataset, transform_fn):
self.dataset = dataset
self.transform_fn = transform_fn
def __getitem__(self, idx):
return self.transform_fn(self.dataset[idx])
def __len__(self):
return len(self.dataset)
class LibriSpeechDataModule(LightningDataModule):
librispeech_cls = torchaudio.datasets.LIBRISPEECH
def __init__(
self,
*,
librispeech_path,
train_transform,
val_transform,
test_transform,
max_tokens=700,
batch_size=2,
train_num_buckets=50,
train_shuffle=True,
num_workers=10,
):
super().__init__()
self.librispeech_path = librispeech_path
self.train_dataset_lengths = None
self.val_dataset_lengths = None
self.train_transform = train_transform
self.val_transform = val_transform
self.test_transform = test_transform
self.max_tokens = max_tokens
self.batch_size = batch_size
self.train_num_buckets = train_num_buckets
self.train_shuffle = train_shuffle
self.num_workers = num_workers
def train_dataloader(self):
datasets = [
self.librispeech_cls(self.librispeech_path, url="train-clean-360"),
self.librispeech_cls(self.librispeech_path, url="train-clean-100"),
self.librispeech_cls(self.librispeech_path, url="train-other-500"),
]
if not self.train_dataset_lengths:
self.train_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
self.train_num_buckets,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.train_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.train_transform)
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=self.num_workers,
batch_size=None,
shuffle=self.train_shuffle,
)
return dataloader
def val_dataloader(self):
datasets = [
self.librispeech_cls(self.librispeech_path, url="dev-clean"),
self.librispeech_cls(self.librispeech_path, url="dev-other"),
]
if not self.val_dataset_lengths:
self.val_dataset_lengths = [get_sample_lengths(dataset) for dataset in datasets]
dataset = torch.utils.data.ConcatDataset(
[
CustomBucketDataset(
dataset,
lengths,
self.max_tokens,
1,
batch_size=self.batch_size,
)
for dataset, lengths in zip(datasets, self.val_dataset_lengths)
]
)
dataset = TransformDataset(dataset, self.val_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, num_workers=self.num_workers)
return dataloader
def test_dataloader(self):
dataset = self.librispeech_cls(self.librispeech_path, url="test-clean")
dataset = TransformDataset(dataset, self.test_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=None)
return dataloader
import logging
import pathlib
from argparse import ArgumentParser
import sentencepiece as spm
import torch
import torchaudio
from lightning import ConformerRNNTModule
from transforms import get_data_module
logger = logging.getLogger()
def compute_word_level_distance(seq1, seq2):
return torchaudio.functional.edit_distance(seq1.lower().split(), seq2.lower().split())
def run_eval(args):
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
model = ConformerRNNTModule.load_from_checkpoint(args.checkpoint_path, sp_model=sp_model).eval()
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
if args.use_cuda:
model = model.to(device="cuda")
total_edit_distance = 0
total_length = 0
dataloader = data_module.test_dataloader()
with torch.no_grad():
for idx, (batch, sample) in enumerate(dataloader):
actual = sample[0][2]
predicted = model(batch)
total_edit_distance += compute_word_level_distance(actual, predicted)
total_length += len(actual.split())
if idx % 100 == 0:
logger.warning(f"Processed elem {idx}; WER: {total_edit_distance / total_length}")
logger.warning(f"Final WER: {total_edit_distance / total_length}")
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
required=True,
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--use-cuda",
action="store_true",
default=False,
help="Run using CUDA.",
)
args = parser.parse_args()
run_eval(args)
if __name__ == "__main__":
cli_main()
{
"mean": [
15.058613777160645,
16.34557342529297,
16.34653663635254,
16.240671157836914,
17.45355224609375,
17.445302963256836,
17.52323341369629,
18.076807022094727,
17.699262619018555,
17.706790924072266,
17.24724578857422,
17.153791427612305,
17.213361740112305,
17.347240447998047,
17.331117630004883,
17.21516227722168,
17.030071258544922,
16.818960189819336,
16.573062896728516,
16.29717254638672,
16.00996971130371,
15.794167518615723,
15.616395950317383,
15.459056854248047,
15.306838989257812,
15.199165344238281,
15.208144187927246,
14.883454322814941,
14.787869453430176,
14.947835922241211,
14.5912504196167,
14.76955509185791,
14.617781639099121,
14.840407371520996,
14.83073616027832,
14.909119606018066,
14.89070987701416,
14.918207168579102,
14.939517974853516,
14.913643836975098,
14.863334655761719,
14.803299903869629,
14.751264572143555,
14.688116073608398,
14.63498306274414,
14.615056037902832,
14.680213928222656,
14.616259574890137,
14.707776069641113,
14.630264282226562,
14.644737243652344,
14.547430038452148,
14.529033660888672,
14.49357795715332,
14.411538124084473,
14.33312702178955,
14.260393142700195,
14.204919815063477,
14.130182266235352,
14.06987476348877,
14.010197639465332,
13.938552856445312,
13.750232696533203,
13.607213973999023,
13.457777976989746,
13.31512451171875,
13.167718887329102,
13.019341468811035,
12.8869047164917,
12.795098304748535,
12.685126304626465,
12.620392799377441,
12.58949089050293,
12.537697792053223,
12.496938705444336,
12.410022735595703,
12.346826553344727,
12.221966743469238,
12.122841835021973,
12.005624771118164
],
"invstddev": [
0.25952333211898804,
0.2590482831001282,
0.24866817891597748,
0.24776232242584229,
0.22200720012187958,
0.21363843977451324,
0.20652402937412262,
0.19909949600696564,
0.2021811604499817,
0.20355898141860962,
0.20546883344650269,
0.2061648815870285,
0.20569036900997162,
0.20412985980510712,
0.20357738435268402,
0.2041499763727188,
0.2055872678756714,
0.20807604491710663,
0.21054454147815704,
0.21341396868228912,
0.21418628096580505,
0.22065168619155884,
0.2248840034008026,
0.22723940014839172,
0.230172261595726,
0.23371541500091553,
0.23734734952449799,
0.23960146307945251,
0.24088498950004578,
0.241532102227211,
0.24218633770942688,
0.24371792376041412,
0.2447739839553833,
0.25564682483673096,
0.2632736265659332,
0.2549223005771637,
0.24608071148395538,
0.2464841604232788,
0.2470586597919464,
0.24785254895687103,
0.24904784560203552,
0.2503036856651306,
0.25226327776908875,
0.2532329559326172,
0.2527913451194763,
0.2518651783466339,
0.2504975199699402,
0.24836081266403198,
0.24765831232070923,
0.24767662584781647,
0.24965286254882812,
0.2501370906829834,
0.2508895993232727,
0.2512582540512085,
0.25150999426841736,
0.2525503635406494,
0.25313329696655273,
0.2534785270690918,
0.25330957770347595,
0.25366073846817017,
0.25502219796180725,
0.2608155608177185,
0.25662899017333984,
0.2558451294898987,
0.25671014189720154,
0.2577403485774994,
0.25914356112480164,
0.2596718966960907,
0.25953933596611023,
0.2610883116722107,
0.26132410764694214,
0.26272818446159363,
0.26397505402565,
0.26440608501434326,
0.26543495059013367,
0.26753780245780945,
0.26935192942619324,
0.26732245087623596,
0.26666897535324097,
0.2663257420063019
]
}
import logging
import math
from collections import namedtuple
from typing import List, Tuple
import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import conformer_rnnt_base
logger = logging.getLogger()
_expected_spm_vocab_size = 1023
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
r"""Learning rate scheduler that performs linear warmup and exponential annealing.
Args:
optimizer (torch.optim.Optimizer): optimizer to use.
warmup_steps (int): number of scheduler steps for which to warm up learning rate.
force_anneal_step (int): scheduler step at which annealing of learning rate begins.
anneal_factor (float): factor to scale base learning rate by at each annealing step.
last_epoch (int, optional): The index of last epoch. (Default: -1)
verbose (bool, optional): If ``True``, prints a message to stdout for
each update. (Default: ``False``)
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: int,
force_anneal_step: int,
anneal_factor: float,
last_epoch=-1,
verbose=False,
):
self.warmup_steps = warmup_steps
self.force_anneal_step = force_anneal_step
self.anneal_factor = anneal_factor
super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
def get_lr(self):
if self._step_count < self.force_anneal_step:
return [(min(1.0, self._step_count / self.warmup_steps)) * base_lr for base_lr in self.base_lrs]
else:
scaling_factor = self.anneal_factor ** (self._step_count - self.force_anneal_step)
return [scaling_factor * base_lr for base_lr in self.base_lrs]
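# Worked illustration of the schedule above (a sketch, not part of the original
# file), using the values ConformerRNNTModule passes below (base_lr=8e-4,
# warmup_steps=40, force_anneal_step=120, anneal_factor=0.96; the scheduler is
# stepped once per epoch, see ``configure_optimizers``):
#   step  10 -> 8e-4 * min(1.0, 10 / 40)    = 2.0e-4   (linear warmup)
#   step  40 -> 8e-4 * min(1.0, 40 / 40)    = 8.0e-4   (base LR reached)
#   step 130 -> 8e-4 * 0.96 ** (130 - 120) ~= 5.3e-4   (exponential annealing)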
def post_process_hypos(
hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, List[float], List[int]]]:
tokens_idx = 0
score_idx = 3
post_process_remove_list = [
sp_model.unk_id(),
sp_model.eos_id(),
sp_model.pad_id(),
]
filtered_hypo_tokens = [
[token_index for token_index in h[tokens_idx][1:] if token_index not in post_process_remove_list] for h in hypos
]
hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
hypos_ids = [h[tokens_idx][1:] for h in hypos]
hypos_score = [[math.exp(h[score_idx])] for h in hypos]
nbest_batch = list(zip(hypos_str, hypos_score, hypos_ids))
return nbest_batch
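# Example of the post-processed output (a sketch; transcripts, probabilities,
# and token ids are illustrative only):
#   post_process_hypos(hypotheses, sp_model)
#   -> [("hello world", [2.1e-05], [35, 228, 41]), ...]
# i.e. one (transcript, [exp(score)], token_ids) tuple per hypothesis;
# ``ConformerRNNTModule.forward`` below takes element [0][0] as the transcript.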
class ConformerRNNTModule(LightningModule):
def __init__(self, sp_model):
super().__init__()
self.sp_model = sp_model
spm_vocab_size = self.sp_model.get_piece_size()
assert spm_vocab_size == _expected_spm_vocab_size, (
"The model returned by conformer_rnnt_base expects a SentencePiece model of "
f"vocabulary size {_expected_spm_vocab_size}, but the given SentencePiece model has a vocabulary size "
f"of {spm_vocab_size}. Please provide a correctly configured SentencePiece model."
)
self.blank_idx = spm_vocab_size
# ``conformer_rnnt_base`` hardcodes a specific Conformer RNN-T configuration.
# For greater customizability, please refer to ``conformer_rnnt_model``.
self.model = conformer_rnnt_base()
self.loss = torchaudio.transforms.RNNTLoss(reduction="sum")
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=8e-4, betas=(0.9, 0.98), eps=1e-9)
self.warmup_lr_scheduler = WarmupLR(self.optimizer, 40, 120, 0.96)
def _step(self, batch, _, step_type):
if batch is None:
return None
prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
prepended_targets[:, 1:] = batch.targets
prepended_targets[:, 0] = self.blank_idx
prepended_target_lengths = batch.target_lengths + 1
output, src_lengths, _, _ = self.model(
batch.features,
batch.feature_lengths,
prepended_targets,
prepended_target_lengths,
)
loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
return (
[self.optimizer],
[{"scheduler": self.warmup_lr_scheduler, "interval": "epoch"}],
)
def forward(self, batch: Batch):
decoder = RNNTBeamSearch(self.model, self.blank_idx)
hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
return post_process_hypos(hypotheses, self.sp_model)[0][0]
def training_step(self, batch: Batch, batch_idx):
"""Custom training step.
By default, DDP does the following on each train step:
- For each GPU, compute loss and gradient on shard of training data.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / N, where N is the world
size (total number of GPUs).
- Update parameters on each GPU.
Here, we do the following:
- For k-th GPU, compute loss and scale it by (N / B_total), where B_total is
the sum of batch sizes across all GPUs. Compute gradient from scaled loss.
- Sync and average gradients across all GPUs. The final gradient
is (sum of gradients across all GPUs) / B_total.
- Update parameters on each GPU.
Doing so allows us to account for the variability in batch sizes that
variable-length sequential data yield.
"""
loss = self._step(batch, batch_idx, "train")
batch_size = batch.features.size(0)
batch_sizes = self.all_gather(batch_size)
self.log("Gathered batch size", batch_sizes.sum(), on_step=True, on_epoch=True)
loss *= batch_sizes.size(0) / batch_sizes.sum()  # world size / total batch size across all GPUs
return loss
def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "val")
def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx, "test")
import pathlib
from argparse import ArgumentParser
import sentencepiece as spm
from lightning import ConformerRNNTModule
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from transforms import get_data_module
def run_train(args):
seed_everything(1)
checkpoint_dir = args.exp_dir / "checkpoints"
checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/val_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
train_checkpoint = ModelCheckpoint(
checkpoint_dir,
monitor="Losses/train_loss",
mode="min",
save_top_k=5,
save_weights_only=False,
verbose=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
callbacks = [
checkpoint,
train_checkpoint,
lr_monitor,
]
trainer = Trainer(
default_root_dir=args.exp_dir,
max_epochs=args.epochs,
num_nodes=args.nodes,
devices=args.gpus,
accelerator="gpu",
strategy=DDPStrategy(find_unused_parameters=False),
callbacks=callbacks,
reload_dataloaders_every_n_epochs=1,
gradient_clip_val=10.0,
)
sp_model = spm.SentencePieceProcessor(model_file=str(args.sp_model_path))
model = ConformerRNNTModule(sp_model)
data_module = get_data_module(str(args.librispeech_path), str(args.global_stats_path), str(args.sp_model_path))
trainer.fit(model, data_module, ckpt_path=args.checkpoint_path)
def cli_main():
parser = ArgumentParser()
parser.add_argument(
"--checkpoint-path",
default=None,
type=pathlib.Path,
help="Path to checkpoint to use for evaluation.",
)
parser.add_argument(
"--exp-dir",
default=pathlib.Path("./exp"),
type=pathlib.Path,
help="Directory to save checkpoints and logs to. (Default: './exp')",
)
parser.add_argument(
"--global-stats-path",
default=pathlib.Path("global_stats.json"),
type=pathlib.Path,
help="Path to JSON file containing feature means and stddevs.",
)
parser.add_argument(
"--librispeech-path",
type=pathlib.Path,
help="Path to LibriSpeech datasets.",
required=True,
)
parser.add_argument(
"--sp-model-path",
type=pathlib.Path,
help="Path to SentencePiece model.",
required=True,
)
parser.add_argument(
"--nodes",
default=4,
type=int,
help="Number of nodes to use for training. (Default: 4)",
)
parser.add_argument(
"--gpus",
default=8,
type=int,
help="Number of GPUs per node to use for training. (Default: 8)",
)
parser.add_argument(
"--epochs",
default=120,
type=int,
help="Number of epochs to train for. (Default: 120)",
)
args = parser.parse_args()
run_train(args)
if __name__ == "__main__":
cli_main()
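# Example invocation of the training script above (script name and paths are
# placeholders; shown for a hypothetical single node with 8 GPUs):
#   python train.py \
#       --librispeech-path ./datasets \
#       --global-stats-path ./global_stats.json \
#       --sp-model-path ./spm_unigram_1023.model \
#       --nodes 1 --gpus 8 --epochs 120
# The SentencePiece model passed via --sp-model-path is produced by the script
# that follows.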
#!/usr/bin/env python3
"""Trains a SentencePiece model on transcripts across LibriSpeech train-clean-100, train-clean-360, and train-other-500.
Example:
python train_spm.py --librispeech-path ./datasets
"""
import io
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter
import sentencepiece as spm
def get_transcript_text(transcript_path):
with open(transcript_path) as f:
return [line.strip().split(" ", 1)[1].lower() for line in f]
def get_transcripts(dataset_path):
transcript_paths = dataset_path.glob("*/*/*.trans.txt")
merged_transcripts = []
for path in transcript_paths:
merged_transcripts += get_transcript_text(path)
return merged_transcripts
def train_spm(sentences):
model_writer = io.BytesIO()
spm.SentencePieceTrainer.train(
sentence_iterator=iter(sentences),
model_writer=model_writer,
vocab_size=1023,
model_type="unigram",
input_sentence_size=-1,
character_coverage=1.0,
bos_id=0,
pad_id=1,
eos_id=2,
unk_id=3,
)
return model_writer.getvalue()
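# The returned bytes are a serialized SentencePiece model. ``run_cli`` below
# writes them to disk so the training script can load them with
# ``spm.SentencePieceProcessor(model_file=...)``; note that vocab_size=1023
# must match ``_expected_spm_vocab_size`` asserted in ConformerRNNTModule.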
def parse_args():
default_output_path = "./spm_unigram_1023.model"
parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--librispeech-path",
required=True,
type=pathlib.Path,
help="Path to LibriSpeech dataset.",
)
parser.add_argument(
"--output-file",
default=pathlib.Path(default_output_path),
type=pathlib.Path,
help=f"File to save model to. (Default: '{default_output_path}')",
)
return parser.parse_args()
def run_cli():
args = parse_args()
root = args.librispeech_path / "LibriSpeech"
splits = ["train-clean-100", "train-clean-360", "train-other-500"]
merged_transcripts = []
for split in splits:
path = pathlib.Path(root) / split
merged_transcripts += get_transcripts(path)
model = train_spm(merged_transcripts)
with open(args.output_file, "wb") as f:
f.write(model)
if __name__ == "__main__":
run_cli()
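# ---------------------------------------------------------------------------
# End of the SentencePiece training script. The code below defines the feature
# and label transforms plus the data-module factory; the module name
# ``transforms`` is inferred from the ``from transforms import get_data_module``
# import in the training script above.
# ---------------------------------------------------------------------------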
import json
import math
from functools import partial
from typing import List
import sentencepiece as spm
import torch
import torchaudio
from data_module import LibriSpeechDataModule
from lightning import Batch
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
_spectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=80, hop_length=160)
def _piecewise_linear_log(x):
x = x * _gain
x[x > math.e] = torch.log(x[x > math.e])
x[x <= math.e] = x[x <= math.e] / math.e
return x
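# A sketch of what the function above computes: the input (mel power-spectrogram
# features, see ``_extract_features`` below) is first scaled by ``_gain``
# (10 ** (0.05 * _decibel), i.e. the int16 full-scale value squared), then
# compressed with a map that is linear near zero and logarithmic elsewhere:
#   y = x * _gain / e      if x * _gain <= e
#   y = log(x * _gain)     otherwise
# Both pieces evaluate to 1.0 at x * _gain = e, so the mapping is continuous
# while avoiding the log of values close to zero.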
class FunctionalModule(torch.nn.Module):
def __init__(self, functional):
super().__init__()
self.functional = functional
def forward(self, input):
return self.functional(input)
class GlobalStatsNormalization(torch.nn.Module):
def __init__(self, global_stats_path):
super().__init__()
with open(global_stats_path) as f:
blob = json.loads(f.read())
self.mean = torch.tensor(blob["mean"])
self.invstddev = torch.tensor(blob["invstddev"])
def forward(self, input):
return (input - self.mean) * self.invstddev
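# ``mean`` and ``invstddev`` are the per-mel-bin statistics from
# ``global_stats.json`` (the JSON shown at the top of this section); storing
# the inverse standard deviation lets normalization be a single multiply
# instead of a divide per frame.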
def _extract_labels(sp_model, samples: List):
targets = [sp_model.encode(sample[2].lower()) for sample in samples]
lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
targets = torch.nn.utils.rnn.pad_sequence(
[torch.tensor(elem) for elem in targets],
batch_first=True,
padding_value=1.0,
).to(dtype=torch.int32)
return targets, lengths
def _extract_features(data_pipeline, samples: List):
mel_features = [_spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
features = data_pipeline(features)
lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
return features, lengths
class TrainTransform:
def __init__(self, global_stats_path: str, sp_model_path: str):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.train_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.FrequencyMasking(27),
torchaudio.transforms.TimeMasking(100, p=0.2),
torchaudio.transforms.TimeMasking(100, p=0.2),
FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
)
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.train_data_pipeline, samples)
targets, target_lengths = _extract_labels(self.sp_model, samples)
return Batch(features, feature_lengths, targets, target_lengths)
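# The training pipeline above adds SpecAugment-style masking on top of the
# log/normalize steps shared with validation: two frequency masks (each hiding
# up to 27 mel bins) and two time masks (each hiding up to 100 frames, capped
# at 20% of the utterance length via p=0.2).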
class ValTransform:
def __init__(self, global_stats_path: str, sp_model_path: str):
self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
self.valid_data_pipeline = torch.nn.Sequential(
FunctionalModule(_piecewise_linear_log),
GlobalStatsNormalization(global_stats_path),
)
def __call__(self, samples: List):
features, feature_lengths = _extract_features(self.valid_data_pipeline, samples)
targets, target_lengths = _extract_labels(self.sp_model, samples)
return Batch(features, feature_lengths, targets, target_lengths)
class TestTransform:
def __init__(self, global_stats_path: str, sp_model_path: str):
self.val_transforms = ValTransform(global_stats_path, sp_model_path)
def __call__(self, sample):
return self.val_transforms([sample]), [sample]
def get_data_module(librispeech_path, global_stats_path, sp_model_path):
train_transform = TrainTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
val_transform = ValTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
test_transform = TestTransform(global_stats_path=global_stats_path, sp_model_path=sp_model_path)
return LibriSpeechDataModule(
librispeech_path=librispeech_path,
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
)