"...git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "38833ceb6295d021b2da640cf693ebad45f3d65b"
Unverified Commit d26b37e7 authored by Suraj Patil, committed by GitHub

Speech2TextTransformer (#10175)



* s2t

* fix config

* conversion script

* fix import

* add tokenizer

* fix tok init

* fix tokenizer

* first version working

* fix embeds

* fix lm head

* remove extra heads

* fix convert script

* handle encoder attn mask

* style

* better enc attn mask

* override _prepare_attention_mask_for_generation

* handle attn_masks in encoder and decoder

* input_ids => input_features

* enable use_cache

* remove old code

* expand embeddings if needed

* remove logits bias

* masked_lm_loss => loss

* hack tokenizer to support feature processing

* fix model_input_names

* style

* fix error message

* doc

* remove inputs_embeds

* remove input_embeds

* remove unnecessary docstring

* quality

* SpeechToText => Speech2Text

* style

* remove shared_embeds

* subsample => conv

* remove Speech2TextTransformerDecoderWrapper

* update output_lengths formula

* fix table

* remove max_position_embeddings

* update conversion scripts

* add possibility to do upper case for now

* add FeatureExtractor and Processor

* add tests for extractor

* require_torch_audio => require_torchaudio

* add processor test

* update import

* remove classification head

* attention mask is now 1D

* update docstrings

* attention mask should be of type long

* handle attention mask from generate

* always return attention_mask

* fix test

* style

* doc

* Speech2TextTransformer => Speech2Text

* Speech2TextTransformerConfig => Speech2TextConfig

* remove dummy_inputs

* nit

* style

* multilingual tok

* fix tokenizer

* add tgt_lang setter

* save lang_codes

* fix tokenizer

* add forced_bos_token_id to tokenizer

* apply review suggestions

* add torchaudio to extra deps

* add speech deps to CI

* fix dep

* add libsndfile to ci

* libsndfile1

* add speech to extras all

* libsndfile1 -> libsndfile1

* libsndfile

* libsndfile1-dev

* apt update

* add sudo to install

* update deps table

* install libsndfile1-dev on CI

* tuple to list

* init conv layer

* add model tests

* quality

* add integration tests

* skip_special_tokens

* add speech_to_text_transformer in toctree

* fix tokenizer

* fix fp16 tests

* add tokenizer tests

* fix copyright

* input_values => input_features

* doc

* add model in readme

* doc

* change checkpoint names

* fix copyright

* fix code example

* add max_model_input_sizes in tokenizer

* fix integration tests

* add do_lower_case to tokenizer

* remove clamp trick

* fix "Add modeling imports here"

* fix copyrights

* fix tests

* SpeechToTextTransformer => SpeechToText

* fix naming

* fix table formatting

* fix typo

* style

* fix typos

* remove speech dep from extras[testing]

* fix copies

* rename doc file,

* put imports under is_torch_available

* run feat extract tests when torch is available

* dummy objects for processor and extractor

* fix imports in tests

* fix import in modeling test

* fix imports

* fix torch import

* fix imports again

* fix positional embeddings

* fix typo in import

* adapt new extractor refactor

* style

* fix torchscript test

* doc

* doc

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix docs, copied from, style

* fix docstring

* handle imports

* remove speech from all extra deps

* remove s2t from seq2seq lm mapping

* better names

* skip training tests

* add install instructions

* List => Tuple

* doc

* fix conversion script

* fix urls

* add instruction for libsndfile

* fix fp16 test
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent efb5c0a4
......@@ -77,8 +77,9 @@ jobs:
keys:
- v0.4-torch_and_tf-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece]
- run: pip install .[sklearn,tf-cpu,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-{{ checksum "setup.py" }}
......@@ -104,8 +105,9 @@ jobs:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
......@@ -157,8 +159,9 @@ jobs:
keys:
- v0.4-flax-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece]
- run: sudo pip install .[flax,sklearn,torch,testing,sentencepiece,speech]
- save_cache:
key: v0.4-flax-{{ checksum "setup.py" }}
paths:
......@@ -183,8 +186,9 @@ jobs:
keys:
- v0.4-torch-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install .[sklearn,torch,testing,sentencepiece]
- run: pip install .[sklearn,torch,testing,sentencepiece,speech]
- run: pip install tapas torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cpu.html
- save_cache:
key: v0.4-torch-{{ checksum "setup.py" }}
......@@ -300,6 +304,7 @@ jobs:
keys:
- v0.4-build_doc-{{ checksum "setup.py" }}
- v0.4-{{ checksum "setup.py" }}
- run: sudo apt-get -y update && sudo apt-get install -y libsndfile1-dev
- run: pip install --upgrade pip
- run: pip install ."[all, docs]"
- save_cache:
......
......@@ -227,6 +227,7 @@ Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, and Wen-tau Yih.
1. **[ProphetNet](https://huggingface.co/transformers/model_doc/prophetnet.html)** (from Microsoft Research) released with the paper [ProphetNet: Predicting Future N-gram for Sequence-to-Sequence Pre-training](https://arxiv.org/abs/2001.04063) by Yu Yan, Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
1. **[Reformer](https://huggingface.co/transformers/model_doc/reformer.html)** (from Google Research) released with the paper [Reformer: The Efficient Transformer](https://arxiv.org/abs/2001.04451) by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya.
1. **[RoBERTa](https://huggingface.co/transformers/model_doc/roberta.html)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SpeechToTextTransformer](https://huggingface.co/transformers/model_doc/speech_to_text.html)** (from Facebook), released together with the paper [fairseq S2T: Fast Speech-to-Text Modeling with fairseq](https://arxiv.org/abs/2010.05171) by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
1. **[SqueezeBert](https://huggingface.co/transformers/model_doc/squeezebert.html)** released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1. **[T5](https://huggingface.co/transformers/model_doc/t5.html)** (from Google AI) released with the paper [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TAPAS](https://huggingface.co/transformers/model_doc/tapas.html)** (from Google AI) released with the paper [TAPAS: Weakly Supervised Table Parsing via Pre-training](https://arxiv.org/abs/2004.02349) by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller, Francesco Piccinno and Julian Martin Eisenschlos.
......
......@@ -191,31 +191,34 @@ and conversion utilities for the following models:
36. :doc:`RoBERTa <model_doc/roberta>` (from Facebook), released together with the paper a `Robustly Optimized BERT
Pretraining Approach <https://arxiv.org/abs/1907.11692>`__ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar
Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
37. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
37. :doc:`SpeechToTextTransformer <model_doc/speech_to_text>` (from Facebook), released together with the paper
`fairseq S2T: Fast Speech-to-Text Modeling with fairseq <https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun
Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino.
38. :doc:`SqueezeBert <model_doc/squeezebert>` released with the paper `SqueezeBERT: What can computer vision teach NLP
about efficient neural networks? <https://arxiv.org/abs/2006.11316>`__ by Forrest N. Iandola, Albert E. Shaw, Ravi
Krishna, and Kurt W. Keutzer.
38. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
39. :doc:`T5 <model_doc/t5>` (from Google AI) released with the paper `Exploring the Limits of Transfer Learning with a
Unified Text-to-Text Transformer <https://arxiv.org/abs/1910.10683>`__ by Colin Raffel and Noam Shazeer and Adam
Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
39. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
40. :doc:`TAPAS <model_doc/tapas>` (from Google AI) released with the paper `TAPAS: Weakly Supervised Table Parsing via
Pre-training <https://arxiv.org/abs/2004.02349>`__ by Jonathan Herzig, Paweł Krzysztof Nowak, Thomas Müller,
Francesco Piccinno and Julian Martin Eisenschlos.
40. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
41. :doc:`Transformer-XL <model_doc/transformerxl>` (from Google/CMU) released with the paper `Transformer-XL:
Attentive Language Models Beyond a Fixed-Length Context <https://arxiv.org/abs/1901.02860>`__ by Zihang Dai*,
Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
41. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
42. :doc:`Wav2Vec2 <model_doc/wav2vec2>` (from Facebook AI) released with the paper `wav2vec 2.0: A Framework for
Self-Supervised Learning of Speech Representations <https://arxiv.org/abs/2006.11477>`__ by Alexei Baevski, Henry
Zhou, Abdelrahman Mohamed, Michael Auli.
42. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
43. :doc:`XLM <model_doc/xlm>` (from Facebook) released together with the paper `Cross-lingual Language Model
Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
43. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
44. :doc:`XLM-ProphetNet <model_doc/xlmprophetnet>` (from Microsoft Research) released with the paper `ProphetNet:
Predicting Future N-gram for Sequence-to-Sequence Pre-training <https://arxiv.org/abs/2001.04063>`__ by Yu Yan,
Weizhen Qi, Yeyun Gong, Dayiheng Liu, Nan Duan, Jiusheng Chen, Ruofei Zhang and Ming Zhou.
44. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
45. :doc:`XLM-RoBERTa <model_doc/xlmroberta>` (from Facebook AI), released together with the paper `Unsupervised
Cross-lingual Representation Learning at Scale <https://arxiv.org/abs/1911.02116>`__ by Alexis Conneau*, Kartikay
Khandelwal*, Naman Goyal, Vishrav Chaudhary, Guillaume Wenzek, Francisco Guzmán, Edouard Grave, Myle Ott, Luke
Zettlemoyer and Veselin Stoyanov.
45. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
46. :doc:`XLNet <model_doc/xlnet>` (from Google/CMU) released with the paper `​XLNet: Generalized Autoregressive
Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang*, Zihang Dai*, Yiming
Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
......@@ -304,6 +307,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Speech2Text | ✅ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| SqueezeBERT | ✅ | ✅ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| T5 | ✅ | ✅ | ✅ | ✅ | ❌ |
......@@ -436,6 +441,7 @@ TensorFlow and/or Flax.
model_doc/reformer
model_doc/retribert
model_doc/roberta
model_doc/speech_to_text
model_doc/squeezebert
model_doc/t5
model_doc/tapas
......
..
Copyright 2021 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
Speech2Text
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The Speech2Text model was proposed in `fairseq S2T: Fast Speech-to-Text Modeling with fairseq
<https://arxiv.org/abs/2010.05171>`__ by Changhan Wang, Yun Tang, Xutai Ma, Anne Wu, Dmytro Okhonko, Juan Pino. It's a
transformer-based seq2seq (encoder-decoder) model designed for end-to-end Automatic Speech Recognition (ASR) and Speech
Translation (ST). It uses a convolutional downsampler to reduce the length of speech inputs by 3/4 before they are
fed into the encoder. The model is trained with standard autoregressive cross-entropy loss and generates the
transcripts/translations autoregressively. Speech2Text has been fine-tuned on several datasets for ASR and ST:
`LibriSpeech <http://www.openslr.org/12>`__, `CoVoST 2 <https://github.com/facebookresearch/covost>`__, `MuST-C
<https://ict.fbk.eu/must-c/>`__.
The original code can be found `here <https://github.com/pytorch/fairseq/tree/master/examples/speech_to_text>`__.
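As an illustration of the downsampling arithmetic, the sketch below computes the number of encoder positions produced
for a given number of log-mel frames, assuming the default two stride-2 convolutions with kernel size 5 and padding 2
(the helper is illustrative only, not a library API):

.. code-block::

>>> def downsampled_length(length, num_conv_layers=2):
...     for _ in range(num_conv_layers):
...         length = (length - 1) // 2 + 1  # output length of a stride-2 conv with kernel 5, padding 2
...     return length
>>> downsampled_length(584)  # 584 input frames -> roughly a quarter as many encoder positions
146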
Inference
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Speech2Text is a speech model that accepts a float tensor of log-mel filter-bank features extracted from the speech
signal. It's a transformer-based seq2seq model, so the transcripts/translations are generated autoregressively. The
:obj:`generate()` method can be used for inference.
The :class:`~transformers.Speech2TextFeatureExtractor` class is responsible for extracting the log-mel filter-bank
features. The :class:`~transformers.Speech2TextProcessor` wraps :class:`~transformers.Speech2TextFeatureExtractor` and
:class:`~transformers.Speech2TextTokenizer` into a single instance to both extract the input features and decode the
predicted token ids.
The feature extractor depends on :obj:`torchaudio` and the tokenizer depends on :obj:`sentencepiece`, so be sure to
install those packages before running the examples. You can either install them as extra speech dependencies with
``pip install transformers"[speech, sentencepiece]"`` or install the packages separately with ``pip install torchaudio
sentencepiece``. Also, ``torchaudio`` requires the development version of the `libsndfile
<http://www.mega-nerd.com/libsndfile/>`__ package, which can be installed via a system package manager. On Ubuntu it
can be installed as follows: ``apt install libsndfile1-dev``
- ASR and Speech Translation
.. code-block::
>>> import torch
>>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
>>> processor = Speech2Textprocessor.from_pretrained("facebook/s2t-small-librispeech-asr")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_features = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt").input_features # Batch size 1
>>> generated_ids = model.generate(input_ids=input_features)
>>> transcription = processor.batch_decode(generated_ids)
- Multilingual speech translation
For multilingual speech translation models, :obj:`eos_token_id` is used as the :obj:`decoder_start_token_id` and
the target language id is forced as the first generated token. To force the target language id as the first
generated token, pass the :obj:`forced_bos_token_id` parameter to the :obj:`generate()` method. The following
example shows how to translate English speech to French text using the `facebook/s2t-medium-mustc-multilingual-st`
checkpoint.
.. code-block::
>>> import torch
>>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
>>> from datasets import load_dataset
>>> import soundfile as sf
>>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
>>> processor = Speech2Textprocessor.from_pretrained("facebook/s2t-medium-mustc-multilingual-st")
>>> def map_to_array(batch):
... speech, _ = sf.read(batch["file"])
... batch["speech"] = speech
... return batch
>>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
>>> ds = ds.map(map_to_array)
>>> input_features = processor(ds["speech"][0], sampling_rate=16_000, return_tensors="pt").input_features # Batch size 1
>>> generated_ids = model.generate(input_ids=input_features, forced_bos_token_id=processor.tokenizer.lang_code_to_id["fr"])
>>> translation = processor.batch_decode(generated_ids)
See the `model hub <https://huggingface.co/models?filter=speech_to_text>`__ to look for Speech2Text checkpoints.
Speech2TextConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Speech2TextConfig
:members:
Speech2TextTokenizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Speech2TextTokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
create_token_type_ids_from_sequences, save_vocabulary
Speech2TextFeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Speech2TextFeatureExtractor
:members: __call__
Speech2TextProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Speech2TextProcessor
:members: __call__, from_pretrained, save_pretrained, batch_decode, decode, as_target_processor
Speech2TextModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Speech2TextModel
:members: forward
Speech2TextForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Speech2TextForConditionalGeneration
:members: forward
......@@ -35,6 +35,7 @@ known_third_party =
tensorflow_datasets
timeout_decorator
torch
torchaudio
torchtext
torchvision
torch_xla
......
......@@ -134,6 +134,7 @@ _deps = [
"timeout-decorator",
"tokenizers>=0.10.1,<0.11",
"torch>=1.0",
"torchaudio",
"tqdm>=4.27",
"unidic>=1.0.2",
"unidic_lite>=1.0.7",
......@@ -227,14 +228,13 @@ extras["onnxruntime"] = deps_list("onnxruntime", "onnxruntime-tools")
extras["modelcreation"] = deps_list("cookiecutter")
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
extras["speech"] = deps_list("soundfile")
extras["speech"] = deps_list("soundfile", "torchaudio")
extras["sentencepiece"] = deps_list("sentencepiece", "protobuf")
extras["testing"] = (
deps_list("pytest", "pytest-xdist", "timeout-decorator", "parameterized", "psutil", "datasets")
+ extras["retrieval"]
+ extras["modelcreation"]
+ extras["speech"]
)
extras["docs"] = deps_list("recommonmark", "sphinx", "sphinx-markdown-tables", "sphinx-rtd-theme", "sphinx-copybutton")
extras["quality"] = deps_list("black", "isort", "flake8")
......
......@@ -135,6 +135,11 @@ _import_structure = {
"Wav2Vec2Processor",
],
"models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config", "M2M100Tokenizer"],
"models.speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
"Speech2TextFeatureExtractor",
],
"models.convbert": ["CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "ConvBertConfig", "ConvBertTokenizer"],
"models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
"models.auto": [
......@@ -275,6 +280,8 @@ if is_sentencepiece_available():
_import_structure["models.mt5"].append("MT5Tokenizer")
_import_structure["models.pegasus"].append("PegasusTokenizer")
_import_structure["models.reformer"].append("ReformerTokenizer")
_import_structure["models.speech_to_text"].append("Speech2TextTokenizer")
_import_structure["models.speech_to_text"].append("Speech2TextProcessor")
_import_structure["models.t5"].append("T5Tokenizer")
_import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer")
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer")
......@@ -377,6 +384,14 @@ if is_torch_available():
_import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"]
# PyTorch models structure
_import_structure["models.speech_to_text"].extend(
[
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
"Speech2TextModel",
]
)
_import_structure["models.wav2vec2"].extend(
[
"WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
......@@ -1379,6 +1394,11 @@ if TYPE_CHECKING:
from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig
from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer
from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer
from .models.speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig,
Speech2TextFeatureExtractor,
)
from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer
from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer
......@@ -1461,6 +1481,7 @@ if TYPE_CHECKING:
from .models.mt5 import MT5Tokenizer
from .models.pegasus import PegasusTokenizer
from .models.reformer import ReformerTokenizer
from .models.speech_to_text import Speech2TextProcessor, Speech2TextTokenizer
from .models.t5 import T5Tokenizer
from .models.xlm_prophetnet import XLMProphetNetTokenizer
from .models.xlm_roberta import XLMRobertaTokenizer
......@@ -1862,6 +1883,11 @@ if TYPE_CHECKING:
RobertaForTokenClassification,
RobertaModel,
)
from .models.speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration,
Speech2TextModel,
)
from .models.squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
SqueezeBertForMaskedLM,
......
......@@ -47,6 +47,7 @@ deps = {
"timeout-decorator": "timeout-decorator",
"tokenizers": "tokenizers>=0.10.1,<0.11",
"torch": "torch>=1.0",
"torchaudio": "torchaudio",
"tqdm": "tqdm>=4.27",
"unidic": "unidic>=1.0.2",
"unidic_lite": "unidic_lite>=1.0.7",
......
......@@ -177,6 +177,13 @@ try:
except importlib_metadata.PackageNotFoundError:
_soundfile_available = False
_torchaudio_available = importlib.util.find_spec("torchaudio") is not None
try:
_torchaudio_version = importlib_metadata.version("torchaudio")
logger.debug(f"Successfully imported torchaudio version {_torchaudio_version}")
except importlib_metadata.PackageNotFoundError:
_torchaudio_available = False
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
old_default_cache_path = os.path.join(torch_cache_home, "transformers")
......@@ -364,6 +371,10 @@ def is_soundfile_availble():
return _soundfile_available
def is_torchaudio_available():
return _torchaudio_available
def torch_only_method(fn):
def wrapper(*args, **kwargs):
if not _torch_available:
......
......@@ -384,7 +384,7 @@ class GenerationMixin:
)
if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id:
return input_ids.ne(pad_token_id).long()
return input_ids.new_ones(input_ids.shape)
return input_ids.new_ones(input_ids.shape, dtype=torch.long)
def _prepare_encoder_decoder_kwargs_for_generation(
self, input_ids: torch.LongTensor, model_kwargs
......@@ -402,8 +402,7 @@ class GenerationMixin:
) -> torch.LongTensor:
decoder_start_token_id = self._get_decoder_start_token_id(decoder_start_token_id, bos_token_id)
decoder_input_ids = (
torch.ones((input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device)
* decoder_start_token_id
torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device) * decoder_start_token_id
)
return decoder_input_ids
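# Note on the two dtype changes above (an illustrative sketch, not library code): for
# speech models, the tensor passed as `input_ids` actually holds float features, so
# `new_ones(...)` / `dtype=input_ids.dtype` would otherwise yield float masks and
# decoder inputs. For example:
#
#     import torch
#     feats = torch.zeros(2, 10)  # float "inputs"
#     feats.new_ones(feats.shape).dtype                    # torch.float32
#     feats.new_ones(feats.shape, dtype=torch.long).dtype  # torch.long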
......
......@@ -60,6 +60,7 @@ from . import (
reformer,
retribert,
roberta,
speech_to_text,
squeezebert,
t5,
tapas,
......
......@@ -58,6 +58,10 @@ from ..rag.configuration_rag import RagConfig
from ..reformer.configuration_reformer import ReformerConfig
from ..retribert.configuration_retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig
from ..roberta.configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from ..speech_to_text.configuration_speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
Speech2TextConfig,
)
from ..squeezebert.configuration_squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig
from ..t5.configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from ..tapas.configuration_tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig
......@@ -76,6 +80,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
# Add archive maps here
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP,
CONVBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
......@@ -122,6 +127,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("speech_to_text", Speech2TextConfig),
("wav2vec2", Wav2Vec2Config),
("m2m_100", M2M100Config),
("convbert", ConvBertConfig),
......@@ -174,6 +180,7 @@ CONFIG_MAPPING = OrderedDict(
MODEL_NAMES_MAPPING = OrderedDict(
[
# Add full (and cased) model names here
("speech_to_text", "Speech2Text"),
("wav2vec2", "Wav2Vec2"),
("m2m_100", "M2M100"),
("convbert", "ConvBERT"),
......
......@@ -66,8 +66,6 @@ from ..camembert.modeling_camembert import (
CamembertForTokenClassification,
CamembertModel,
)
# Add modeling imports here
from ..convbert.modeling_convbert import (
ConvBertForMaskedLM,
ConvBertForMultipleChoice,
......@@ -211,6 +209,7 @@ from ..roberta.modeling_roberta import (
RobertaForTokenClassification,
RobertaModel,
)
from ..speech_to_text.modeling_speech_to_text import Speech2TextForConditionalGeneration, Speech2TextModel
from ..squeezebert.modeling_squeezebert import (
SqueezeBertForMaskedLM,
SqueezeBertForMultipleChoice,
......@@ -296,6 +295,7 @@ from .configuration_auto import (
ReformerConfig,
RetriBertConfig,
RobertaConfig,
Speech2TextConfig,
SqueezeBertConfig,
T5Config,
TapasConfig,
......@@ -315,6 +315,7 @@ logger = logging.get_logger(__name__)
MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(Speech2TextConfig, Speech2TextModel),
(Wav2Vec2Config, Wav2Vec2Model),
(M2M100Config, M2M100Model),
(ConvBertConfig, ConvBertModel),
......@@ -399,6 +400,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
[
# Model with LM heads mapping
(Speech2TextConfig, Speech2TextForConditionalGeneration),
(Wav2Vec2Config, Wav2Vec2ForMaskedLM),
(M2M100Config, M2M100ForConditionalGeneration),
(ConvBertConfig, ConvBertForMaskedLM),
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_sentencepiece_available, is_torch_available
_import_structure = {
"configuration_speech_to_text": [
"SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP",
"Speech2TextConfig",
],
"feature_extraction_speech_to_text": ["Speech2TextFeatureExtractor"],
}
if is_sentencepiece_available():
_import_structure["tokenization_speech_to_text"] = ["Speech2TextTokenizer"]
_import_structure["processing_speech_to_text"] = ["Speech2TextProcessor"]
if is_torch_available():
_import_structure["modeling_speech_to_text"] = [
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration",
"Speech2TextModel",
"Speech2TextPreTrainedModel",
]
if TYPE_CHECKING:
from .configuration_speech_to_text import SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP, Speech2TextConfig
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
if is_sentencepiece_available():
from .processing_speech_to_text import Speech2TextProcessor
from .tokenization_speech_to_text import Speech2TextTokenizer
if is_torch_available():
from .modeling_speech_to_text import (
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration,
Speech2TextModel,
Speech2TextPreTrainedModel,
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Speech2Text model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
SPEECH_TO_TEXT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/config.json",
# See all Speech2Text models at https://huggingface.co/models?filter=speech_to_text
}
class Speech2TextConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.Speech2TextModel`. It is used
to instantiate a Speech2Text model according to the specified arguments, defining the model architecture.
Instantiating a configuration with the defaults will yield a similar configuration to that of the Speech2Text
`facebook/s2t-small-librispeech-asr <https://huggingface.co/facebook/s2t-small-librispeech-asr>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 10000):
Vocabulary size of the Speech2Text model. Defines the number of different tokens that can be represented by
the :obj:`input_ids` passed when calling :class:`~transformers.Speech2TextModel`
d_model (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the layers and the pooler layer.
encoder_layers (:obj:`int`, `optional`, defaults to 12):
Number of encoder layers.
decoder_layers (:obj:`int`, `optional`, defaults to 6):
Number of decoder layers.
encoder_attention_heads (:obj:`int`, `optional`, defaults to 4):
Number of attention heads for each attention layer in the Transformer encoder.
decoder_attention_heads (:obj:`int`, `optional`, defaults to 4):
Number of attention heads for each attention layer in the Transformer decoder.
decoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048):
Dimensionality of the "intermediate" (often named feed-forward) layer in the decoder.
encoder_ffn_dim (:obj:`int`, `optional`, defaults to 2048):
Dimensionality of the "intermediate" (often named feed-forward) layer in the encoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"relu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for the attention probabilities.
activation_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for activations inside the fully connected layer.
classifier_dropout (:obj:`float`, `optional`, defaults to 0.0):
The dropout ratio for classifier.
init_std (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
encoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the encoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
decoder_layerdrop (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
max_source_positions (:obj:`int`, `optional`, defaults to 6000):
The maximum sequence length of log-mel filter-bank features that this model might ever be used with.
max_target_positions (:obj:`int`, `optional`, defaults to 1024):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
num_conv_layers (:obj:`int`, `optional`, defaults to 2):
Number of 1D convolutional layers in the conv module.
conv_kernel_sizes (:obj:`Tuple[int]`, `optional`, defaults to :obj:`(5, 5)`):
A tuple of integers defining the kernel size of each 1D convolutional layer in the conv module. The length
of :obj:`conv_kernel_sizes` has to match :obj:`num_conv_layers`.
conv_channels (:obj:`int`, `optional`, defaults to 1024):
An integer defining the number of output channels of each convolutional layer except the final one in the
conv module.
input_feat_per_channel (:obj:`int`, `optional`, defaults to 80):
An integer specifying the size of the feature vector. This is also the dimension of the log-mel filter-bank
features.
input_channels (:obj:`int`, `optional`, defaults to 1):
An integer specifying the number of input channels of the input feature vector.
Example::
>>> from transformers import Speech2TextModel, Speech2TextConfig
>>> # Initializing a Speech2Text s2t_transformer_s style configuration
>>> configuration = Speech2TextConfig()
>>> # Initializing a model from the s2t_transformer_s style configuration
>>> model = Speech2TextModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
"""
model_type = "speech_to_text"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=10000,
encoder_layers=12,
encoder_ffn_dim=2048,
encoder_attention_heads=4,
decoder_layers=6,
decoder_ffn_dim=2048,
decoder_attention_heads=4,
encoder_layerdrop=0.0,
decoder_layerdrop=0.0,
use_cache=True,
is_encoder_decoder=True,
activation_function="relu",
d_model=256,
dropout=0.1,
attention_dropout=0.0,
activation_dropout=0.0,
init_std=0.02,
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=True,
gradient_checkpointing=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
max_source_positions=6000,
max_target_positions=1024,
num_conv_layers=2,
conv_kernel_sizes=(5, 5),
conv_channels=1024,
input_feat_per_channel=80,
input_channels=1,
**kwargs
):
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
is_encoder_decoder=is_encoder_decoder,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model
self.encoder_ffn_dim = encoder_ffn_dim
self.encoder_layers = encoder_layers
self.encoder_attention_heads = encoder_attention_heads
self.decoder_ffn_dim = decoder_ffn_dim
self.decoder_layers = decoder_layers
self.decoder_attention_heads = decoder_attention_heads
self.dropout = dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.activation_function = activation_function
self.init_std = init_std
self.encoder_layerdrop = encoder_layerdrop
self.decoder_layerdrop = decoder_layerdrop
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.num_conv_layers = num_conv_layers
self.conv_kernel_sizes = list(conv_kernel_sizes)
self.conv_channels = conv_channels
self.input_feat_per_channel = input_feat_per_channel
self.input_channels = input_channels
if len(self.conv_kernel_sizes) != self.num_conv_layers:
raise ValueError(
"Configuration for convolutional module is incorrect."
"It is required that `len(config.conv_kernel_sizes)` == `config.num_conv_layers`"
f"but is `len(config.conv_kernel_sizes) = {len(self.conv_kernel_sizes)}`,"
f"`config.num_conv_layers = {self.num_conv_layers}`."
)
@property
def num_attention_heads(self) -> int:
return self.encoder_attention_heads
@property
def hidden_size(self) -> int:
return self.d_model
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import torch
from torch import nn
from transformers import Speech2TextConfig, Speech2TextForConditionalGeneration
def remove_ignore_keys_(state_dict):
ignore_keys = [
"encoder.version",
"decoder.version",
"model.encoder.version",
"model.decoder.version",
"decoder.output_projection.weight",
"_float_tensor",
"encoder.embed_positions._float_tensor",
"decoder.embed_positions._float_tensor",
]
for k in ignore_keys:
state_dict.pop(k, None)
def rename_keys(s_dict):
keys = list(s_dict.keys())
for key in keys:
if "transformer_layers" in key:
s_dict[key.replace("transformer_layers", "layers")] = s_dict.pop(key)
elif "subsample" in key:
s_dict[key.replace("subsample", "conv")] = s_dict.pop(key)
def make_linear_from_emb(emb):
vocab_size, emb_size = emb.weight.shape
lin_layer = nn.Linear(emb_size, vocab_size, bias=False)  # maps hidden states to vocab logits
lin_layer.weight.data = emb.weight.data
return lin_layer
def convert_fairseq_s2t_checkpoint_to_tfms(checkpoint_path, pytorch_dump_folder_path):
ckpt = torch.load(checkpoint_path, map_location="cpu")
args = ckpt["args"]
state_dict = ckpt["model"]
lm_head_weights = state_dict["decoder.output_projection.weight"]
remove_ignore_keys_(state_dict)
rename_keys(state_dict)
vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]
tie_embeds = args.share_decoder_input_output_embed
conv_kernel_sizes = [int(i) for i in args.conv_kernel_sizes.split(",")]
config = Speech2TextConfig(
vocab_size=vocab_size,
max_source_positions=args.max_source_positions,
max_target_positions=args.max_target_positions,
encoder_layers=args.encoder_layers,
decoder_layers=args.decoder_layers,
encoder_attention_heads=args.encoder_attention_heads,
decoder_attention_heads=args.decoder_attention_heads,
encoder_ffn_dim=args.encoder_ffn_embed_dim,
decoder_ffn_dim=args.decoder_ffn_embed_dim,
d_model=args.encoder_embed_dim,
dropout=args.dropout,
attention_dropout=args.attention_dropout,
activation_dropout=args.activation_dropout,
activation_function="relu",
num_conv_layers=len(conv_kernel_sizes),
conv_channels=args.conv_channels,
conv_kernel_sizes=conv_kernel_sizes,
input_feat_per_channel=args.input_feat_per_channel,
input_channels=args.input_channels,
tie_word_embeddings=tie_embeds,
num_beams=5,
max_length=200,
use_cache=True,
decoder_start_token_id=2,
early_stopping=True,
)
model = Speech2TextForConditionalGeneration(config)
model.model.load_state_dict(state_dict)
if tie_embeds:
model.lm_head = make_linear_from_emb(model.model.decoder.embed_tokens)
else:
model.lm_head.weight.data = lm_head_weights
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("fairseq_path", type=str, help="Path to the fairseq model (.pt) file.")
parser.add_argument("pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
args = parser.parse_args()
convert_fairseq_s2t_checkpoint_to_tfms(args.fairseq_path, args.pytorch_dump_folder_path)
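# A hypothetical usage sketch for the function above (the paths are placeholders):
#
#     from transformers import Speech2TextForConditionalGeneration
#
#     convert_fairseq_s2t_checkpoint_to_tfms("/path/to/s2t_checkpoint.pt", "./s2t-converted")
#     model = Speech2TextForConditionalGeneration.from_pretrained("./s2t-converted")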
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Feature extractor class for Speech2Text
"""
from typing import List, Optional, Union
import numpy as np
from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
from ...feature_extraction_utils import BatchFeature
from ...file_utils import PaddingStrategy, TensorType, is_torch_available, is_torchaudio_available
from ...utils import logging
if is_torch_available():
import torch
if is_torchaudio_available():
import torchaudio.compliance.kaldi as ta_kaldi
logger = logging.get_logger(__name__)
class Speech2TextFeatureExtractor(SequenceFeatureExtractor):
r"""
Constructs a Speech2Text feature extractor.
This feature extractor inherits from :class:`~transformers.SequenceFeatureExtractor` which contains most of the
main methods. Users should refer to this superclass for more information regarding those methods.
This class extracts mel-filter bank features from raw speech using TorchAudio and applies utterance-level cepstral
mean and variance normalization to the extracted features.
Args:
feature_size (:obj:`int`, defaults to 80):
The feature dimension of the extracted features.
sampling_rate (:obj:`int`, defaults to 16000):
The sampling rate at which the audio files should be digitized, expressed in hertz (Hz).
num_mel_bins (:obj:`int`, defaults to 80):
Number of Mel-frequency bins.
padding_value (:obj:`float`, defaults to 0.0):
The value that is used to fill the padding vectors.
do_ceptral_normalize (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to apply utterance-level cepstral mean and variance normalization to extracted features.
normalize_means (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to zero-mean normalize the extracted features.
normalize_vars (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to unit-variance normalize the extracted features.
"""
model_input_names = ["input_features", "attention_mask"]
def __init__(
self,
feature_size=80,
sampling_rate=16000,
num_mel_bins=80,
padding_value=0.0,
do_ceptral_normalize=True,
normalize_means=True,
normalize_vars=True,
**kwargs
):
if not is_torchaudio_available():
raise ImportError("`Speech2TextFeatureExtractor` requires torchaudio: `pip install torchaudio`.")
super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)
self.num_mel_bins = num_mel_bins
self.do_ceptral_normalize = do_ceptral_normalize
self.normalize_means = normalize_means
self.normalize_vars = normalize_vars
self.return_attention_mask = True
def _extract_fbank_features(
self,
waveform: np.ndarray,
) -> np.ndarray:
"""
Get mel-filter bank features using TorchAudio. Note that TorchAudio requires 16-bit signed integers as inputs
and hence the waveform should not be normalized before feature extraction.
"""
waveform = waveform * (2 ** 15) # Kaldi compliance: 16-bit signed integers
waveform = torch.from_numpy(waveform).unsqueeze(0)
features = ta_kaldi.fbank(waveform, num_mel_bins=self.num_mel_bins, sample_frequency=self.sampling_rate)
return features.numpy()
@staticmethod
def utterance_cmvn(
x: np.ndarray, normalize_means: Optional[bool] = True, normalize_vars: Optional[bool] = True
) -> np.ndarray:
"""Apply utterance-level cepstral mean and variance normalization (CMVN) to a single feature matrix."""
mean = x.mean(axis=0)
square_sums = (x ** 2).sum(axis=0)
if normalize_means:
x = np.subtract(x, mean)
if normalize_vars:
# var = E[x^2] - E[x]^2; clamp to avoid dividing by a (near-)zero std
var = square_sums / x.shape[0] - mean ** 2
std = np.sqrt(np.maximum(var, 1e-10))
x = np.divide(x, std)
return x
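# A quick sanity check for the CMVN above (illustrative only; the input is synthetic):
#
#     import numpy as np
#     x = np.random.randn(100, 80) * 3.0 + 5.0
#     y = Speech2TextFeatureExtractor.utterance_cmvn(x)
#     assert np.allclose(y.mean(axis=0), 0.0, atol=1e-8)   # per-dimension zero mean
#     assert np.allclose(y.std(axis=0), 1.0, atol=1e-6)    # per-dimension unit variance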
def normalize(self, input_values: List[np.ndarray]) -> List[np.ndarray]:
return [self.utterance_cmvn(x, self.normalize_means, self.normalize_vars) for x in input_values]
def __call__(
self,
raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
padding: Union[bool, str, PaddingStrategy] = False,
max_length: Optional[int] = None,
pad_to_multiple_of: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
sampling_rate: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
**kwargs
) -> BatchFeature:
"""
Main method to featurize and prepare one or several sequence(s) for the model.
Args:
raw_speech (:obj:`np.ndarray`, :obj:`List[float]`, :obj:`List[np.ndarray]`, :obj:`List[List[float]]`):
The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
values, a list of numpy arrays or a list of list of float values.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.file_utils.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
max_length (:obj:`int`, `optional`):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (:obj:`int`, `optional`):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
>= 7.5 (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128.
return_attention_mask (:obj:`bool`, `optional`):
Whether to return the attention mask. If left to the default, will return the attention mask according
to the specific feature_extractor's default.
`What are attention masks? <../glossary.html#attention-mask>`__
.. note::
For Speech2Text models, :obj:`attention_mask` should always be passed for batched
inference, to avoid subtle bugs.
return_tensors (:obj:`str` or :class:`~transformers.file_utils.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
sampling_rate (:obj:`int`, `optional`):
The sampling rate at which the :obj:`raw_speech` input was sampled. It is strongly recommended to pass
:obj:`sampling_rate` at the forward call to prevent silent errors.
padding_value (:obj:`float`, defaults to 0.0):
The value that is used to fill the padding values / vectors.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of {self.sampling_rate}."
f"Please make sure that the provided `raw_speech` input was sampled with {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function."
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
)
# make sure input is in list format
if is_batched and not isinstance(raw_speech[0], np.ndarray):
raw_speech = [np.asarray(speech) for speech in raw_speech]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech)
# always return batch
if not is_batched:
raw_speech = [raw_speech]
# extract fbank features
features = [self._extract_fbank_features(waveform) for waveform in raw_speech]
# Utterance-level cepstral mean and variance normalization
if self.do_ceptral_normalize:
features = self.normalize(features)
# convert into correct format for padding
encoded_inputs = BatchFeature({"input_features": features})
padded_inputs = self.pad(
encoded_inputs,
padding=padding,
max_length=max_length,
pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask,
return_tensors=return_tensors,
**kwargs,
)
return padded_inputs
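# A minimal usage sketch for the extractor above (assumes torchaudio is installed;
# the one-second waveform is synthetic):
#
#     import numpy as np
#     from transformers import Speech2TextFeatureExtractor
#
#     extractor = Speech2TextFeatureExtractor()  # defaults: 80 mel bins, 16 kHz
#     waveform = np.random.randn(16000).astype(np.float32)
#     inputs = extractor(waveform, sampling_rate=16000, return_tensors="pt")
#     # inputs.input_features has shape (1, num_frames, 80)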
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Speech processor class for Speech2Text
"""
from contextlib import contextmanager
from .feature_extraction_speech_to_text import Speech2TextFeatureExtractor
from .tokenization_speech_to_text import Speech2TextTokenizer
class Speech2TextProcessor:
r"""
Constructs a Speech2Text processor which wraps a Speech2Text feature extractor and a Speech2Text tokenizer into a
single processor.
:class:`~transformers.Speech2TextProcessor` offers all the functionalities of
:class:`~transformers.Speech2TextFeatureExtractor` and :class:`~transformers.Speech2TextTokenizer`. See the
:meth:`~transformers.Speech2TextProcessor.__call__` and :meth:`~transformers.Speech2TextProcessor.decode` for more
information.
Args:
feature_extractor (:obj:`Speech2TextFeatureExtractor`):
An instance of :class:`~transformers.Speech2TextFeatureExtractor`. The feature extractor is a required
input.
tokenizer (:obj:`Speech2TextTokenizer`):
An instance of :class:`~transformers.Speech2TextTokenizer`. The tokenizer is a required input.
"""
def __init__(self, feature_extractor, tokenizer):
if not isinstance(feature_extractor, Speech2TextFeatureExtractor):
raise ValueError(
f"`feature_extractor` has to be of type {Speech2TextFeatureExtractor.__class__}, but is {type(feature_extractor)}"
)
if not isinstance(tokenizer, Speech2TextTokenizer):
raise ValueError(
f"`tokenizer` has to be of type {Speech2TextTokenizer.__class__}, but is {type(tokenizer)}"
)
self.feature_extractor = feature_extractor
self.tokenizer = tokenizer
self.current_processor = self.feature_extractor
def save_pretrained(self, save_directory):
"""
Save a Speech2Text feature extractor object and Speech2Text tokenizer object to the directory
``save_directory``, so that it can be re-loaded using the
:func:`~transformers.Speech2TextProcessor.from_pretrained` class method.
.. note::
This class method is simply calling :meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` and
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.save_pretrained`. Please refer to the
docstrings of the methods above for more information.
Args:
save_directory (:obj:`str` or :obj:`os.PathLike`):
Directory where the feature extractor JSON file and the tokenizer files will be saved (directory will
be created if it does not exist).
"""
self.feature_extractor.save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
r"""
Instantiate a :class:`~transformers.Speech2TextProcessor` from a pretrained Speech2Text processor.
.. note::
This class method is simply calling Speech2TextFeatureExtractor's
:meth:`~transformers.PreTrainedFeatureExtractor.from_pretrained` and Speech2TextTokenizer's
:meth:`~transformers.tokenization_utils_base.PreTrainedTokenizer.from_pretrained`. Please refer to the
docstrings of the methods above for more information.
Args:
pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`):
This can be either:
- a string, the `model id` of a pretrained feature_extractor hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like ``bert-base-uncased``, or
namespaced under a user or organization name, like ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing a feature extractor file saved using the
:meth:`~transformers.PreTrainedFeatureExtractor.save_pretrained` method, e.g.,
``./my_model_directory/``.
- a path or url to a saved feature extractor JSON `file`, e.g.,
``./my_model_directory/feature_extraction_config.json``.
**kwargs
Additional keyword arguments passed along to both :class:`~transformers.PreTrainedFeatureExtractor` and
:class:`~transformers.PreTrainedTokenizer`
"""
feature_extractor = Speech2TextFeatureExtractor.from_pretrained(pretrained_model_name_or_path, **kwargs)
tokenizer = Speech2TextTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(feature_extractor=feature_extractor, tokenizer=tokenizer)
def __call__(self, *args, **kwargs):
"""
When used in normal mode, this method forwards all its arguments to Speech2TextFeatureExtractor's
:meth:`~transformers.Speech2TextFeatureExtractor.__call__` and returns its output. If used in the context
:meth:`~transformers.Speech2TextProcessor.as_target_processor` this method forwards all its arguments to
        Speech2TextTokenizer's :meth:`~transformers.Speech2TextTokenizer.__call__`. Please refer to the docstring of
        the two methods above for more information.
"""
return self.current_processor(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Speech2TextTokenizer's
:meth:`~transformers.PreTrainedTokenizer.batch_decode`. Please refer to the docstring of this method for more
information.
"""
return self.tokenizer.batch_decode(*args, **kwargs)
def decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Speech2TextTokenizer's
:meth:`~transformers.PreTrainedTokenizer.decode`. Please refer to the docstring of this method for more
information.
"""
return self.tokenizer.decode(*args, **kwargs)
@contextmanager
def as_target_processor(self):
"""
Temporarily sets the tokenizer for processing the input. Useful for encoding the labels when fine-tuning
Speech2Text.
"""
self.current_processor = self.tokenizer
yield
self.current_processor = self.feature_extractor
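# Illustrative usage sketch (added for clarity; not part of the original file).
# A minimal, hedged example of the intended workflow, assuming the
# "facebook/s2t-small-librispeech-asr" checkpoint referenced elsewhere in this PR:
# the processor acts as the feature extractor for audio inputs and, inside
# `as_target_processor`, as the tokenizer for text labels.
if __name__ == "__main__":
    import numpy as np

    processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
    waveform = np.zeros(16000, dtype=np.float32)  # one second of dummy 16 kHz audio
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
    with processor.as_target_processor():
        labels = processor("a dummy transcription", return_tensors="pt").input_ids
    # Round trip: saving then reloading restores both sub-objects.
    processor.save_pretrained("./s2t_processor")
    reloaded = Speech2TextProcessor.from_pretrained("./s2t_processor")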
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Speech2Text."""
import json
from pathlib import Path
from shutil import copyfile
from typing import Dict, List, Optional, Tuple, Union
import sentencepiece
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import logging
logger = logging.get_logger(__name__)
SPIECE_UNDERLINE = "▁"
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"spm_file": "sentencepiece.bpe.model",
}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/vocab.json",
},
"spm_file": {
"facebook/s2t-small-librispeech-asr": "https://huggingface.co/facebook/s2t-small-librispeech-asr/resolve/main/sentencepiece.bpe.model"
},
}
MAX_MODEL_INPUT_SIZES = {
"facebook/s2t-small-librispeech-asr": 1024,
}
MUSTC_LANGS = ["pt", "fr", "ru", "nl", "ro", "it", "es", "de"]
LANGUAGES = {"mustc": MUSTC_LANGS}
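# Illustrative note (added): for lang_codes="mustc", the tokenizer below exposes
# language tokens of the form "<lang:fr>", "<lang:de>", ... as additional special
# tokens and uses them as decoder prefix tokens.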
class Speech2TextTokenizer(PreTrainedTokenizer):
"""
    Construct a Speech2Text tokenizer.
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains some of the main methods.
Users should refer to the superclass for more information regarding such methods.
Args:
vocab_file (:obj:`str`):
File containing the vocabulary.
spm_file (:obj:`str`):
            Path to the `SentencePiece <https://github.com/google/sentencepiece>`__ model file.
bos_token (:obj:`str`, `optional`, defaults to :obj:`"<s>"`):
The beginning of sentence token.
eos_token (:obj:`str`, `optional`, defaults to :obj:`"</s>"`):
The end of sentence token.
unk_token (:obj:`str`, `optional`, defaults to :obj:`"<unk>"`):
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
token instead.
pad_token (:obj:`str`, `optional`, defaults to :obj:`"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
do_upper_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to uppercase the output when decoding.
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to lowercase the input when tokenizing.
tgt_lang (:obj:`str`, `optional`):
A string representing the target language.
**kwargs
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = MAX_MODEL_INPUT_SIZES
model_input_names = ["input_ids", "attention_mask"]
prefix_tokens: List[int] = []
def __init__(
self,
vocab_file,
spm_file,
bos_token="<s>",
eos_token="</s>",
pad_token="<pad>",
unk_token="<unk>",
do_upper_case=False,
do_lower_case=False,
tgt_lang=None,
lang_codes=None,
**kwargs,
):
super().__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
pad_token=pad_token,
do_upper_case=do_upper_case,
do_lower_case=do_lower_case,
tgt_lang=tgt_lang,
lang_codes=lang_codes,
**kwargs,
)
self.do_upper_case = do_upper_case
self.do_lower_case = do_lower_case
self.encoder = load_json(vocab_file)
self.decoder = {v: k for k, v in self.encoder.items()}
self.spm_file = spm_file
self.sp_model = load_spm(spm_file)
if lang_codes is not None:
self.lang_codes = lang_codes
self.langs = LANGUAGES[lang_codes]
self.lang_tokens = [f"<lang:{lang}>" for lang in self.langs]
self.lang_code_to_id = {lang: self.sp_model.PieceToId(f"<lang:{lang}>") for lang in self.langs}
self._additional_special_tokens = self.lang_tokens
self._tgt_lang = tgt_lang if tgt_lang is not None else self.langs[0]
self.set_tgt_lang_special_tokens(self._tgt_lang)
else:
self.lang_code_to_id = {}
@property
def vocab_size(self) -> int:
return len(self.encoder)
@property
def tgt_lang(self) -> str:
return self._tgt_lang
@tgt_lang.setter
def tgt_lang(self, new_tgt_lang) -> None:
self._tgt_lang = new_tgt_lang
self.set_tgt_lang_special_tokens(new_tgt_lang)
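    # Note (illustrative, multilingual checkpoints only): assigning
    # `tokenizer.tgt_lang = "fr"` re-runs `set_tgt_lang_special_tokens`, so
    # subsequent encodings are prefixed with the id of the "<lang:fr>" piece.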
def set_tgt_lang_special_tokens(self, tgt_lang: str) -> None:
"""Reset the special tokens to the target language setting. prefix=[eos, tgt_lang_code] and suffix=[eos]."""
lang_code_id = self.lang_code_to_id[tgt_lang]
self.prefix_tokens = [lang_code_id]
def _tokenize(self, text: str) -> List[str]:
return self.sp_model.EncodeAsPieces(text)
def _convert_token_to_id(self, token):
return self.encoder.get(token, self.encoder[self.unk_token])
def _convert_id_to_token(self, index: int) -> str:
"""Converts an index (integer) in a token (str) using the decoder."""
return self.decoder.get(index, self.unk_token)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
if self.do_upper_case:
out_string = out_string.upper()
return out_string
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
"""Build model inputs from a sequence by appending eos_token_id."""
if token_ids_1 is None:
return self.prefix_tokens + token_ids_0 + [self.eos_token_id]
# We don't expect to process pairs, but leave the pair logic for API consistency
return self.prefix_tokens + token_ids_0 + token_ids_1 + [self.eos_token_id]
def get_special_tokens_mask(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
"""
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
special tokens using the tokenizer ``prepare_for_model`` method.
Args:
token_ids_0 (:obj:`List[int]`):
List of IDs.
token_ids_1 (:obj:`List[int]`, `optional`):
Optional second list of IDs for sequence pairs.
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the token list is already formatted with special tokens for the model.
Returns:
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formatted with special tokens for the model."
)
            return [1 if token in [self.bos_token_id, self.eos_token_id] else 0 for token in token_ids_0]
prefix_ones = [1] * len(self.prefix_tokens)
suffix_ones = [1]
if token_ids_1 is None:
return prefix_ones + ([0] * len(token_ids_0)) + suffix_ones
return prefix_ones + ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
def get_vocab(self) -> Dict:
vocab = self.encoder.copy()
vocab.update(self.added_tokens_encoder)
return vocab
def __getstate__(self) -> Dict:
state = self.__dict__.copy()
state["sp_model"] = None
return state
def __setstate__(self, d: Dict) -> None:
self.__dict__ = d
self.sp_model = load_spm(self.spm_file)
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
save_dir = Path(save_directory)
        if not save_dir.is_dir():
            raise OSError(f"{save_directory} should be a directory")
vocab_save_path = save_dir / (
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab_file"]
)
spm_save_path = save_dir / (
(filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["spm_file"]
)
save_json(self.encoder, vocab_save_path)
if not spm_save_path.exists():
copyfile(self.spm_file, spm_save_path)
return (str(vocab_save_path), str(spm_save_path))
def load_spm(path: str) -> sentencepiece.SentencePieceProcessor:
spm = sentencepiece.SentencePieceProcessor()
spm.Load(str(path))
return spm
def load_json(path: str) -> Union[Dict, List]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def save_json(data, path: str) -> None:
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
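# Illustrative usage sketch (added for clarity; not part of the original file).
# A hedged example against the "facebook/s2t-small-librispeech-asr" checkpoint
# named above: encoding appends the eos token (plus any language-code prefix),
# and decode() reverses the mapping.
if __name__ == "__main__":
    tokenizer = Speech2TextTokenizer.from_pretrained("facebook/s2t-small-librispeech-asr")
    ids = tokenizer("hello world").input_ids  # prefix_tokens + subword ids + [eos_token_id]
    print(tokenizer.decode(ids, skip_special_tokens=True))
    # For a multilingual checkpoint (lang_codes="mustc"), switching the target
    # language would look like: tokenizer.tgt_lang = "fr"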