Unverified Commit ac2bc50a authored by Matthijs Hollemans, committed by GitHub

TTS fine-tuning for SpeechT5 (#21824)



* wrong argument name

* append eos_token_id

* all tokenizers need mask and ctc_blank tokens

* remove reduction factor from feature extractor

* add proper TTS loss

* did shifting the wrong way around

* mask out padded portions

* remove logits again (don't really need it)

* fix unit tests

* fixup

* pad also returns the decoder attention mask, since that's useful to have

* clean up feature extractor logic

* pad can handle TTS task too

* remove stop_labels from loss calculation

* simplify logic

* fixup

* do -100 masking properly

* small STFT optimization (calculate mel filterbanks only once)

* replace torchaudio fbanks with audio_utils

* remove torchaudio dependency

* simplify & speed up the STFT

* don't serialize window and mel filters

* output cross attentions when generating speech

* add guided attention loss

* fix failing test

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update src/transformers/models/speecht5/modeling_speecht5.py
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* change type annotation of attention_mask to LongTensor

* extract loss into class

* remove unused frame_signal_scale argument

* use config object in loss class

* fix type annotations in doc comments

* change optional to just bool

* implement missing tokenizer method

* add deprecation warning

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/models/speecht5/feature_extraction_speecht5.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* add deprecation warning for stop_labels

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent dacd3456
@@ -161,13 +161,22 @@ class SpeechT5Config(PretrainedConfig):
         speech_decoder_postnet_dropout (`float`, *optional*, defaults to 0.5):
             The dropout probability for the speech decoder post-net layers.
         reduction_factor (`int`, *optional*, defaults to 2):
-            Spectrogram length reduction factor for the speech decoder post-net.
+            Spectrogram length reduction factor for the speech decoder inputs.
         max_speech_positions (`int`, *optional*, defaults to 4000):
             The maximum sequence length of speech features that this model might ever be used with.
         max_text_positions (`int`, *optional*, defaults to 450):
             The maximum sequence length of text features that this model might ever be used with.
         encoder_max_relative_position (`int`, *optional*, defaults to 160):
             Maximum distance for relative position embedding in the encoder.
+        use_guided_attention_loss (`bool`, *optional*, defaults to `True`):
+            Whether to apply guided attention loss while training the TTS model.
+        guided_attention_loss_num_heads (`int`, *optional*, defaults to 2):
+            Number of attention heads the guided attention loss will be applied to. Use -1 to apply this loss to all
+            attention heads.
+        guided_attention_loss_sigma (`float`, *optional*, defaults to 0.4):
+            Standard deviation for guided attention loss.
+        guided_attention_loss_scale (`float`, *optional*, defaults to 10.0):
+            Scaling coefficient for guided attention loss (also known as lambda).
         use_cache (`bool`, *optional*, defaults to `True`):
             Whether or not the model should return the last key/values attentions (not used by all models).

@@ -241,6 +250,10 @@ class SpeechT5Config(PretrainedConfig):
         max_speech_positions=4000,
         max_text_positions=450,
         encoder_max_relative_position=160,
+        use_guided_attention_loss=True,
+        guided_attention_loss_num_heads=2,
+        guided_attention_loss_sigma=0.4,
+        guided_attention_loss_scale=10.0,
         use_cache=True,
         is_encoder_decoder=True,
         **kwargs,

@@ -311,6 +324,12 @@ class SpeechT5Config(PretrainedConfig):
         self.max_speech_positions = max_speech_positions
         self.max_text_positions = max_text_positions
         self.encoder_max_relative_position = encoder_max_relative_position
+        self.use_guided_attention_loss = use_guided_attention_loss
+        self.guided_attention_loss_num_heads = guided_attention_loss_num_heads
+        self.guided_attention_loss_sigma = guided_attention_loss_sigma
+        self.guided_attention_loss_scale = guided_attention_loss_scale
         self.use_cache = use_cache
         self.is_encoder_decoder = is_encoder_decoder
......
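The four new `guided_attention_loss_*` options drive a guided attention loss of the kind introduced for DC-TTS (Tachibana et al., 2017): cross-attention mass far from the input/output diagonal is penalized so the text-to-speech alignment stays roughly monotonic. The following is only a minimal sketch of that penalty under these config names, not the actual code in `modeling_speecht5.py`:

```python
import torch

def guided_attention_penalty(attn: torch.Tensor, sigma: float = 0.4, scale: float = 10.0) -> torch.Tensor:
    """Illustrative sketch: penalize cross-attention mass that strays from the diagonal.

    attn: attention weights of shape (num_heads, output_len, input_len), e.g. the first
    `guided_attention_loss_num_heads` heads of a decoder layer's cross-attention.
    """
    num_heads, output_len, input_len = attn.shape
    t = torch.arange(output_len, dtype=torch.float32) / output_len  # decoder steps, normalized to 0..1
    s = torch.arange(input_len, dtype=torch.float32) / input_len    # encoder steps, normalized to 0..1
    # soft-diagonal weights: close to 0 on the diagonal, approaching 1 far away from it
    weights = 1.0 - torch.exp(-((s[None, :] - t[:, None]) ** 2) / (2.0 * sigma**2))
    # attention mass landing off-diagonal is penalized, scaled by lambda
    return scale * (attn * weights[None]).mean()
```

Per the "mask out padded portions" commit above, the real loss also excludes padded encoder and decoder positions before averaging, so padding never contributes to the penalty.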
@@ -351,12 +351,11 @@ def convert_speecht5_checkpoint(
     if vocab_path:
         tokenizer = SpeechT5Tokenizer(vocab_path, model_max_length=config.max_text_positions)

-        if task == "pretrain":
-            # Mask token behaves like a normal word, i.e. include the space before it
-            mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
-            tokenizer.mask_token = mask_token
-            tokenizer.add_special_tokens({"mask_token": mask_token})
-            tokenizer.add_tokens(["<ctc_blank>"])
+        # Mask token behaves like a normal word, i.e. include the space before it
+        mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
+        tokenizer.mask_token = mask_token
+        tokenizer.add_special_tokens({"mask_token": mask_token})
+        tokenizer.add_tokens(["<ctc_blank>"])

         feature_extractor = SpeechT5FeatureExtractor()
         processor = SpeechT5Processor(tokenizer=tokenizer, feature_extractor=feature_extractor)
......
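Since `<mask>` and `<ctc_blank>` are now added for every task rather than only for pretraining, a converted tokenizer always carries both tokens at the end of its vocabulary. A quick hedged check, assuming the released `microsoft/speecht5_tts` checkpoint (the ids follow from the 79-token base vocabulary plus the two added tokens, as the tokenizer tests below assert):

```python
from transformers import SpeechT5Tokenizer

tokenizer = SpeechT5Tokenizer.from_pretrained("microsoft/speecht5_tts")
print(tokenizer.mask_token)                            # "<mask>"
print(tokenizer.convert_tokens_to_ids("<ctc_blank>"))  # 80: appended after <mask> (79)
```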
@@ -14,12 +14,13 @@
 # limitations under the License.
 """Feature extractor class for SpeechT5."""

-from typing import List, Optional, Union
+import warnings
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
-import torchaudio

+from ...audio_utils import get_mel_filter_banks
 from ...feature_extraction_sequence_utils import SequenceFeatureExtractor
 from ...feature_extraction_utils import BatchFeature
 from ...utils import PaddingStrategy, TensorType, logging

@@ -60,7 +61,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         win_function (`str`, *optional*, defaults to `"hann_window"`):
             Name for the window function used for windowing, must be accessible via `torch.{win_function}`
         frame_signal_scale (`float`, *optional*, defaults to 1.0):
-            Constant multiplied in creating the frames before applying DFT.
+            Constant multiplied in creating the frames before applying DFT. This argument is deprecated.
         fmin (`float`, *optional*, defaults to 80):
             Minimum mel frequency in Hz.
         fmax (`float`, *optional*, defaults to 7600):

@@ -68,7 +69,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         mel_floor (`float`, *optional*, defaults to 1e-10):
             Minimum value of mel frequency banks.
         reduction_factor (`int`, *optional*, defaults to 2):
-            Spectrogram length reduction factor.
+            Spectrogram length reduction factor. This argument is deprecated.
         return_attention_mask (`bool`, *optional*, defaults to `True`):
             Whether or not [`~SpeechT5FeatureExtractor.__call__`] should return `attention_mask`.
     """

@@ -109,10 +110,33 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         self.sample_size = win_length * sampling_rate // 1000
         self.sample_stride = hop_length * sampling_rate // 1000
         self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
         self.n_freqs = (self.n_fft // 2) + 1

+        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
+        self.window = window.numpy().astype(np.float64)
+
+        self.mel_filters = get_mel_filter_banks(
+            nb_frequency_bins=self.n_freqs,
+            nb_mel_filters=self.num_mel_bins,
+            frequency_min=self.fmin,
+            frequency_max=self.fmax,
+            sample_rate=self.sampling_rate,
+            norm="slaney",
+            mel_scale="slaney",
+        )
+
+        if frame_signal_scale != 1.0:
+            warnings.warn(
+                "The argument `frame_signal_scale` is deprecated and will be removed in version 4.30.0 of Transformers",
+                FutureWarning,
+            )
+        if reduction_factor != 2.0:
+            warnings.warn(
+                "The argument `reduction_factor` is deprecated and will be removed in version 4.30.0 of Transformers",
+                FutureWarning,
+            )
+
     @staticmethod
     # Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
     def zero_mean_unit_var_norm(
@@ -137,99 +161,45 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         return normed_input_values

     @staticmethod
-    def _center_pad(one_waveform, n_fft, pad_mode):
-        padding = [(int(n_fft // 2), int(n_fft // 2))]
-        return np.pad(one_waveform, padding, mode=pad_mode)
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._num_frames_calc
-    def _num_frames_calc(in_size, frame_size, frame_stride):
-        return int(1 + np.floor((in_size - frame_size) * 1 / frame_stride))
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._frame_signal
-    def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
-        scale = frame_signal_scale
-        frames = np.zeros(n_frames * window_length)
-        for frame_idx in range(n_frames):
-            start = frame_idx * window_length
-            end = (frame_idx + 1) * window_length
-            wave_start = frame_idx * sample_stride
-            wave_end = frame_idx * sample_stride + window_length
-            frames[start:end] = scale * one_waveform[wave_start:wave_end]
-        return frames
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._windowing
-    def _windowing(frames, window_length, window):
-        if frames.size % window_length != 0:
-            raise ValueError(
-                f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
-                f" window_length={window_length}."
-            )
-        shaped = frames.reshape(-1, window_length)
-        shaped = window * shaped
-        return shaped
-
-    @staticmethod
-    # Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft
-    def _dft(frames, K, n_frames, n_samples, n_fft):
-        dft = np.zeros([n_frames, K])
-        for frame in range(n_frames):
-            begin = frame * n_samples
-            inwards_buffer = frames[begin : begin + n_samples]
-            inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
-            out = np.fft.rfft(inwards_buffer)
-            dft[frame] = np.abs(out[:K])
-        return dft
-
-    def _extract_fbank_features(
-        self,
-        one_waveform: np.ndarray,
-    ) -> np.ndarray:
-        """
-        Extracts log-mel filterbank features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC
-        code and librosa.
-        """
-        one_waveform = self._center_pad(one_waveform, self.n_fft, "reflect")
-
-        n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)
-
-        frames = self._frame_signal(
-            one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
-        )
-
-        window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
-        window = window.numpy()
-
-        frames = self._windowing(frames, self.sample_size, window)
-
-        dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)
-
-        fbanks = torchaudio.functional.melscale_fbanks(
-            n_freqs=self.n_freqs,
-            f_min=self.fmin,
-            f_max=self.fmax,
-            n_mels=self.num_mel_bins,
-            sample_rate=self.sampling_rate,
-            norm="slaney",
-            mel_scale="slaney",
-        )
-        fbanks = fbanks.numpy()
-
-        return np.log10(np.maximum(self.mel_floor, np.dot(dft_out, fbanks)))
-
-    def _reduce(self, inputs):
-        reduced = []
-        for i in range(len(inputs)):
-            reduced.append(inputs[i][self.reduction_factor - 1 :: self.reduction_factor])
-        return reduced
+    def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
+        """
+        Calculates the magnitude spectrogram over one waveform array.
+        """
+        # center pad the waveform
+        padding = [(int(fft_length // 2), int(fft_length // 2))]
+        waveform = np.pad(waveform, padding, mode="reflect")
+        waveform_size = waveform.size
+
+        # promote to float64, since np.fft uses float64 internally
+        waveform = waveform.astype(np.float64)
+
+        num_frames = int(1 + np.floor((waveform_size - fft_length) / hop_length))
+        num_frequency_bins = (fft_length // 2) + 1
+        spectrogram = np.empty((num_frames, num_frequency_bins))
+
+        start = 0
+        for frame_idx in range(num_frames):
+            frame = waveform[start : start + fft_length] * window
+            spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))
+            start += hop_length
+
+        return spectrogram
+
+    def _extract_mel_features(
+        self,
+        one_waveform: np.ndarray,
+    ) -> np.ndarray:
+        """
+        Extracts log-mel filterbank features for one waveform array (unbatched).
+        """
+        if self.n_fft != self.sample_size:
+            raise NotImplementedError(
+                f"Currently the STFT frame size must be a power of two, but got {self.sample_size} for a window length of {self.win_length} and sampling rate of {self.sampling_rate}. Ensure `win_length * sampling_rate // 1000` is divisible by two."
+            )
+
+        stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)
+
+        return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))

     def __call__(
         self,
@@ -341,7 +311,6 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
             return inputs_target
         else:
             inputs["labels"] = inputs_target["input_values"]
-            inputs["stop_labels"] = inputs_target["stop_labels"]
             decoder_attention_mask = inputs_target.get("attention_mask")
             if decoder_attention_mask is not None:
                 inputs["decoder_attention_mask"] = decoder_attention_mask

@@ -381,8 +350,7 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
         # convert into correct format for padding
         if is_target:
-            features = [self._extract_fbank_features(waveform) for waveform in speech]
-            fbank_sizes = [len(x) for x in features]
+            features = [self._extract_mel_features(waveform) for waveform in speech]
             encoded_inputs = BatchFeature({"input_values": features})
             self.feature_size = self.num_mel_bins
         else:

@@ -429,22 +397,18 @@ class SpeechT5FeatureExtractor(SequenceFeatureExtractor):
                 padded_inputs["input_values"], attention_mask=attention_mask, padding_value=self.padding_value
             )

-        if is_target:
-            # make labels for stop prediction
-            stop_labels = []
-            for i, l in enumerate(fbank_sizes):
-                labels = np.zeros(len(padded_inputs["input_values"][i]))
-                labels[l - 1 :] = 1.0
-                stop_labels.append(labels)
-            padded_inputs["stop_labels"] = stop_labels
-
-            # thin out frames for reduction factor
-            if self.reduction_factor > 1:
-                padded_inputs["input_values"] = self._reduce(padded_inputs["input_values"])
-                if attention_mask is not None:
-                    padded_inputs["attention_mask"] = self._reduce(padded_inputs["attention_mask"])
-
         if return_tensors is not None:
             padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

         return padded_inputs
+
+    def to_dict(self) -> Dict[str, Any]:
+        output = super().to_dict()
+
+        # Don't serialize these as they are derived from the other properties.
+        names = ["window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"]
+        for name in names:
+            if name in output:
+                del output[name]
+
+        return output
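With torchaudio gone, target feature extraction is plain NumPy: the window and mel filter banks are built once in `__init__`, `_stft` produces a magnitude spectrogram, and a matrix product with `mel_filters` followed by `log10` yields the log-mel targets. A hedged usage sketch, assuming the default settings (16 kHz audio, 80 mel bins, 16 ms hop, 64 ms window):

```python
import numpy as np
from transformers import SpeechT5FeatureExtractor

feature_extractor = SpeechT5FeatureExtractor()

# one second of noise as stand-in audio
waveform = np.random.randn(16000).astype(np.float32)
targets = feature_extractor(audio_target=waveform, sampling_rate=16000, return_tensors="np")

# one frame per 16 ms hop, so roughly (1, 63, 80) for one second of audio
print(targets["input_values"].shape)
```

Because `window`, `mel_filters`, and the other derived attributes are recomputed from the constructor arguments, the new `to_dict` drops them and a serialized feature extractor config stays compact.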
@@ -87,59 +87,85 @@ class SpeechT5Processor(ProcessorMixin):
         inputs = None

         if audio_target is not None:
-            audio_target_features = self.feature_extractor(
-                audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs
-            )
-            if inputs is None:
-                return audio_target_features
-            else:
-                inputs["labels"] = audio_target_features["input_values"]
-                inputs["stop_labels"] = audio_target_features["stop_labels"]
-                decoder_attention_mask = audio_target_features.get("attention_mask")
-                if decoder_attention_mask is not None:
-                    inputs["decoder_attention_mask"] = decoder_attention_mask
-
-        if text_target is not None:
-            encodings_target = self.tokenizer(text_target, **kwargs)
-            if inputs is None:
-                return encodings_target
-            else:
-                inputs["labels"] = encodings_target["input_ids"]
-                decoder_attention_mask = encodings_target.get("attention_mask")
-                if decoder_attention_mask is not None:
-                    inputs["decoder_attention_mask"] = decoder_attention_mask
+            targets = self.feature_extractor(audio_target=audio_target, *args, sampling_rate=sampling_rate, **kwargs)
+            labels = targets["input_values"]
+        elif text_target is not None:
+            targets = self.tokenizer(text_target, **kwargs)
+            labels = targets["input_ids"]
+        else:
+            targets = None
+
+        if inputs is None:
+            return targets
+
+        if targets is not None:
+            inputs["labels"] = labels
+
+            decoder_attention_mask = targets.get("attention_mask")
+            if decoder_attention_mask is not None:
+                inputs["decoder_attention_mask"] = decoder_attention_mask

         return inputs

     def pad(self, *args, **kwargs):
         """
-        This method forwards all its arguments to SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`] and
-        returns its output.
-
-        You can process your labels by using the argument `text` (either in the same call as your audio inputs, or in a
-        separate call). This forwards its arguments to SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`].
+        Collates the audio and text inputs, as well as their targets, into a padded batch.
+
+        Audio inputs are padded by SpeechT5FeatureExtractor's [`~SpeechT5FeatureExtractor.pad`]. Text inputs are padded
+        by SpeechT5Tokenizer's [`~SpeechT5Tokenizer.pad`].
+
+        Valid input combinations are:
+
+        - `input_ids` only
+        - `input_values` only
+        - `labels` only, either log-mel spectrograms or text tokens
+        - `input_ids` and log-mel spectrogram `labels`
+        - `input_values` and text `labels`

         Please refer to the docstring of the above two methods for more information.
         """
-        input_features = kwargs.pop("input_features", None)
+        input_values = kwargs.pop("input_values", None)
+        input_ids = kwargs.pop("input_ids", None)
         labels = kwargs.pop("labels", None)

-        if len(args) > 0:
-            input_features = args[0]
-            args = args[1:]
-
-        if input_features is not None:
-            input_features = self.feature_extractor.pad(input_features, *args, **kwargs)
-        if labels is not None:
-            labels = self.tokenizer.pad(labels, **kwargs)
-
-        if labels is None:
-            return input_features
-        elif input_features is None:
-            return labels
-        else:
-            input_features["labels"] = labels["input_ids"]
-            return input_features
+        if input_values is not None and input_ids is not None:
+            raise ValueError("Cannot process both `input_values` and `input_ids` inputs.")
+        if input_values is None and input_ids is None and labels is None:
+            raise ValueError(
+                "You need to specify either an `input_values`, `input_ids`, or `labels` input to be padded."
+            )
+
+        if input_values is not None:
+            inputs = self.feature_extractor.pad(input_values, *args, **kwargs)
+        elif input_ids is not None:
+            inputs = self.tokenizer.pad(input_ids, **kwargs)
+        else:
+            inputs = None
+
+        if labels is not None:
+            if "input_ids" in labels or (isinstance(labels, list) and "input_ids" in labels[0]):
+                targets = self.tokenizer.pad(labels, **kwargs)
+                labels = targets["input_ids"]
+            else:
+                feature_size_hack = self.feature_extractor.feature_size
+                self.feature_extractor.feature_size = self.feature_extractor.num_mel_bins
+                targets = self.feature_extractor.pad(labels, *args, **kwargs)
+                self.feature_extractor.feature_size = feature_size_hack
+                labels = targets["input_values"]
+        else:
+            targets = None
+
+        if inputs is None:
+            return targets
+
+        if targets is not None:
+            inputs["labels"] = labels
+
+            decoder_attention_mask = targets.get("attention_mask")
+            if decoder_attention_mask is not None:
+                inputs["decoder_attention_mask"] = decoder_attention_mask
+
+        return inputs

     def batch_decode(self, *args, **kwargs):
         """
......
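Together with the reworked `pad`, a TTS fine-tuning collator can batch text inputs and spectrogram labels in a single call. A minimal sketch (field names are illustrative; the -100 fill mirrors the "do -100 masking properly" commit so the loss skips padded target frames):

```python
import torch

def tts_collate_fn(processor, examples):
    """Hedged sketch of a data collator for SpeechT5 TTS fine-tuning."""
    input_ids = [{"input_ids": example["input_ids"]} for example in examples]
    labels = [{"input_values": example["labels"]} for example in examples]

    # pads the token ids via the tokenizer and the log-mel labels via the feature extractor
    batch = processor.pad(input_ids=input_ids, labels=labels, return_tensors="pt")

    # replace padded target frames with -100 so the loss ignores them
    batch["labels"] = batch["labels"].masked_fill(
        batch["decoder_attention_mask"].unsqueeze(-1).ne(1), -100
    )
    return batch
```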
@@ -167,6 +167,26 @@ class SpeechT5Tokenizer(PreTrainedTokenizer):
         out_string += self.sp_model.decode(current_sub_tokens)
         return out_string.strip()

+    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
+        """Build model inputs from a sequence by appending eos_token_id."""
+        if token_ids_1 is None:
+            return token_ids_0 + [self.eos_token_id]
+        # We don't expect to process pairs, but leave the pair logic for API consistency
+        return token_ids_0 + token_ids_1 + [self.eos_token_id]
+
+    def get_special_tokens_mask(
+        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+    ) -> List[int]:
+        if already_has_special_tokens:
+            return super().get_special_tokens_mask(
+                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+            )
+
+        suffix_ones = [1]
+        if token_ids_1 is None:
+            return ([0] * len(token_ids_0)) + suffix_ones
+        return ([0] * len(token_ids_0)) + ([0] * len(token_ids_1)) + suffix_ones
+
     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
......
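The appended `eos_token_id` and the previously missing `get_special_tokens_mask` can be sanity-checked directly (a hedged illustration, assuming the released `microsoft/speecht5_tts` tokenizer):

```python
from transformers import SpeechT5Tokenizer

tokenizer = SpeechT5Tokenizer.from_pretrained("microsoft/speecht5_tts")

ids = tokenizer("hello").input_ids
print(ids[-1] == tokenizer.eos_token_id)  # True: </s> is now appended to every sequence

# mask for a sequence *before* special tokens are added: zeros plus a trailing 1 for eos
print(tokenizer.get_special_tokens_mask(ids[:-1]))
```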
@@ -21,7 +21,7 @@ import unittest
 import numpy as np

 from transformers import BatchFeature, is_speech_available
-from transformers.testing_utils import require_torch, require_torchaudio
+from transformers.testing_utils import require_torch
 from transformers.utils.import_utils import is_torch_available

 from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin

@@ -67,11 +67,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
         hop_length=16,
         win_length=64,
         win_function="hann_window",
-        frame_signal_scale=1.0,
         fmin=80,
         fmax=7600,
         mel_floor=1e-10,
-        reduction_factor=2,
         return_attention_mask=True,
     ):
         self.parent = parent

@@ -87,11 +85,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
         self.hop_length = hop_length
         self.win_length = win_length
         self.win_function = win_function
-        self.frame_signal_scale = frame_signal_scale
         self.fmin = fmin
         self.fmax = fmax
         self.mel_floor = mel_floor
-        self.reduction_factor = reduction_factor
         self.return_attention_mask = return_attention_mask

     def prepare_feat_extract_dict(self):

@@ -104,11 +100,9 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
             "hop_length": self.hop_length,
             "win_length": self.win_length,
             "win_function": self.win_function,
-            "frame_signal_scale": self.frame_signal_scale,
             "fmin": self.fmin,
             "fmax": self.fmax,
             "mel_floor": self.mel_floor,
-            "reduction_factor": self.reduction_factor,
             "return_attention_mask": self.return_attention_mask,
         }

@@ -147,7 +141,6 @@ class SpeechT5FeatureExtractionTester(unittest.TestCase):
 @require_torch
-@require_torchaudio
 class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
     feature_extraction_class = SpeechT5FeatureExtractor if is_speech_available() else None

@@ -407,10 +400,10 @@ class SpeechT5FeatureExtractionTest(SequenceFeatureExtractionTestMixin, unittest.TestCase):
     def test_integration_target(self):
         # fmt: off
         EXPECTED_INPUT_VALUES = torch.tensor(
-            [-2.7713, -2.8896, -3.2619, -3.0843, -2.9919, -3.0084, -3.2796, -3.3169,
-             -3.2397, -3.2053, -2.9151, -2.7921, -2.9403, -2.7411, -3.0654, -2.8314,
-             -3.0026, -2.9797, -3.1314, -2.9939, -2.6748, -2.7725, -2.8563, -2.9462,
-             -3.2623, -3.3044, -3.1318, -3.2672, -3.4030, -3.1988]
+            [-2.6870, -3.0104, -3.1356, -3.5352, -3.0044, -3.0353, -3.4719, -3.6777,
+             -3.1520, -2.9435, -2.6553, -2.8795, -2.9944, -2.5921, -3.0279, -3.0386,
+             -3.0864, -3.1291, -3.2353, -2.7444, -2.6831, -2.7287, -3.1761, -3.1571,
+             -3.2726, -3.0582, -3.1007, -3.4533, -3.4695, -3.0998]
         )
         # fmt: on
......
@@ -25,10 +25,10 @@ from transformers.testing_utils import (
     require_sentencepiece,
     require_tokenizers,
     require_torch,
-    require_torchaudio,
     slow,
     torch_device,
 )
+from transformers.trainer_utils import set_seed
 from transformers.utils import cached_property

 from ...test_configuration_common import ConfigTester

@@ -716,7 +716,6 @@ class SpeechT5ForSpeechToTextTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
-@require_torchaudio
 @require_sentencepiece
 @require_tokenizers
 @slow

@@ -991,7 +990,6 @@ class SpeechT5ForTextToSpeechTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
-@require_torchaudio
 @require_sentencepiece
 @require_tokenizers
 @slow

@@ -1005,11 +1003,13 @@ class SpeechT5ForTextToSpeechIntegrationTests(unittest.TestCase):
         model.to(torch_device)
         processor = self.default_processor

+        set_seed(555)  # make deterministic
+
         input_text = "mister quilter is the apostle of the middle classes and we are glad to welcome his gospel"
         input_ids = processor(text=input_text, return_tensors="pt").input_ids.to(torch_device)

         generated_speech = model.generate_speech(input_ids)
-        self.assertEqual(generated_speech.shape, (1800, model.config.num_mel_bins))
+        self.assertEqual(generated_speech.shape, (1820, model.config.num_mel_bins))

@@ -1406,7 +1406,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
 @require_torch
-@require_torchaudio
 @require_sentencepiece
 @require_tokenizers
 @slow
......
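The integration test now seeds the RNG because the SpeechT5 speech decoder pre-net keeps dropout active even at inference time, so generated lengths vary between runs. A hedged sketch of the corresponding usage (zeros stand in for a real 512-dim x-vector speaker embedding):

```python
import torch
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor
from transformers.trainer_utils import set_seed

processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts")

set_seed(555)  # make generation reproducible

inputs = processor(text="hello world", return_tensors="pt")
speaker_embeddings = torch.zeros((1, 512))  # stand-in; use a real x-vector in practice
spectrogram = model.generate_speech(inputs["input_ids"], speaker_embeddings)
print(spectrogram.shape)  # (num_frames, config.num_mel_bins)
```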
@@ -21,7 +21,7 @@ import unittest
 from transformers import is_speech_available, is_torch_available
 from transformers.models.speecht5 import SpeechT5Tokenizer
-from transformers.testing_utils import get_tests_dir, require_torch, require_torchaudio
+from transformers.testing_utils import get_tests_dir, require_torch
 from transformers.utils import FEATURE_EXTRACTOR_NAME

@@ -35,7 +35,6 @@ SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_bpe_char.model")
 @require_torch
-@require_torchaudio
 class SpeechT5ProcessorTest(unittest.TestCase):
     def setUp(self):
         self.tmpdirname = tempfile.mkdtemp()

@@ -52,7 +51,6 @@ class SpeechT5ProcessorTest(unittest.TestCase):
             "hop_length": 16,
             "win_length": 64,
             "win_function": "hann_window",
-            "frame_signal_scale": 1.0,
             "fmin": 80,
             "fmax": 7600,
             "mel_floor": 1e-10,
......
@@ -19,6 +19,7 @@ import unittest
 from transformers import SPIECE_UNDERLINE
 from transformers.models.speecht5 import SpeechT5Tokenizer
 from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
+from transformers.tokenization_utils import AddedToken

 from ...test_tokenization_common import TokenizerTesterMixin

@@ -38,6 +39,12 @@ class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         # We have a SentencePiece fixture for testing
         tokenizer = SpeechT5Tokenizer(SAMPLE_VOCAB)
+
+        mask_token = AddedToken("<mask>", lstrip=True, rstrip=False)
+        tokenizer.mask_token = mask_token
+        tokenizer.add_special_tokens({"mask_token": mask_token})
+        tokenizer.add_tokens(["<ctc_blank>"])
+
         tokenizer.save_pretrained(self.tmpdirname)

     def get_input_output_texts(self, tokenizer):

@@ -64,8 +71,10 @@ class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(vocab_keys[0], "<s>")
         self.assertEqual(vocab_keys[1], "<pad>")
-        self.assertEqual(vocab_keys[-2], "œ")
-        self.assertEqual(len(vocab_keys), 79)
+        self.assertEqual(vocab_keys[-4], "œ")
+        self.assertEqual(vocab_keys[-2], "<mask>")
+        self.assertEqual(vocab_keys[-1], "<ctc_blank>")
+        self.assertEqual(len(vocab_keys), 81)

     def test_vocab_size(self):
         self.assertEqual(self.get_tokenizer().vocab_size, 79)
@@ -175,7 +184,18 @@ class SpeechT5TokenizerTest(TokenizerTesterMixin, unittest.TestCase):
]

# fmt: off
expected_encoding = {'input_ids': [[4, 32, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 64, 19, 8, 13, 18, 5, 13, 15, 22, 4, 28, 9, 8, 20, 9, 4, 7, 12, 4, 24, 22, 6, 8, 13, 17, 11, 39, 6, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 7, 9, 14, 4, 24, 22, 6, 8, 13, 17, 11, 39, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 39, 25, 5, 13, 6, 63, 4, 24, 13, 8, 27, 10, 14, 5, 12, 4, 21, 5, 9, 5, 13, 7, 15, 39, 24, 16, 13, 24, 8, 12, 5, 4, 7, 13, 17, 11, 10, 6, 5, 17, 6, 16, 13, 5, 12, 4, 64, 40, 47, 54, 32, 23, 4, 53, 49, 32, 23, 4, 54, 8, 40, 47, 54, 32, 7, 23, 4, 69, 52, 43, 23, 4, 51, 10, 12, 6, 10, 15, 40, 5, 13, 6, 23, 4, 69, 52, 48, 5, 6, 26, 26, 26, 63, 4, 19, 8, 13, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 61, 9, 14, 5, 13, 12, 6, 7, 9, 14, 10, 9, 21, 4, 64, 48, 52, 61, 63, 4, 7, 9, 14, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 53, 5, 9, 5, 13, 7, 6, 10, 8, 9, 4, 64, 48, 52, 53, 63, 4, 20, 10, 6, 11, 4, 8, 27, 5, 13, 4, 6, 11, 10, 13, 6, 22, 39, 6, 20, 8, 4, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 4, 18, 8, 14, 5, 15, 12, 4, 10, 9, 4, 8, 9, 5, 4, 11, 16, 9, 14, 13, 5, 14, 4, 24, 15, 16, 12, 4, 15, 7, 9, 21, 16, 7, 21, 5, 12, 4, 7, 9, 14, 4, 14, 5, 5, 24, 4, 10, 9, 6, 5, 13, 8, 24, 5, 13, 7, 25, 10, 15, 10, 6, 22, 4, 25, 5, 6, 20, 5, 5, 9, 4, 58, 7, 37, 23, 4, 49, 22, 32, 8, 13, 17, 11, 4, 7, 9, 14, 4, 32, 5, 9, 12, 8, 13, 55, 15, 8, 20, 26], [4, 40, 47, 54, 32, 4, 10, 12, 4, 14, 5, 12, 10, 21, 9, 5, 14, 4, 6, 8, 4, 24, 13, 5, 39, 6, 13, 7, 10, 9, 4, 14, 5, 5, 24, 4, 25, 10, 14, 10, 13, 5, 17, 6, 10, 8, 9, 7, 15, 4, 13, 5, 24, 13, 5, 12, 5, 9, 6, 7, 6, 10, 8, 9, 12, 4, 19, 13, 8, 18, 4, 16, 9, 15, 7, 25, 5, 15, 5, 14, 4, 6, 5, 37, 6, 4, 25, 22, 4, 46, 8, 10, 9, 6, 15, 22, 4, 17, 8, 9, 14, 10, 6, 10, 8, 9, 10, 9, 21, 4, 8, 9, 4, 25, 8, 6, 11, 4, 15, 5, 19, 6, 4, 7, 9, 14, 4, 13, 10, 21, 11, 6, 4, 17, 8, 9, 6, 5, 37, 6, 4, 10, 9, 4, 7, 15, 15, 4, 15, 7, 22, 5, 13, 12, 26, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [4, 32, 11, 5, 4, 45, 16, 10, 17, 28, 4, 25, 13, 8, 20, 9, 4, 19, 8, 37, 4, 46, 16, 18, 24, 12, 4, 8, 27, 5, 13, 4, 6, 11, 5, 4, 15, 7, 57, 22, 4, 14, 8, 21, 26, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}
expected_encoding = {
'input_ids': [
[4, 32, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 64, 19, 8, 13, 18, 5, 13, 15, 22, 4, 28, 9, 8, 20, 9, 4, 7, 12, 4, 24, 22, 6, 8, 13, 17, 11, 39, 6, 13, 7, 9, 12, 19, 8, 13, 18, 5, 13, 12, 4, 7, 9, 14, 4, 24, 22, 6, 8, 13, 17, 11, 39, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 39, 25, 5, 13, 6, 63, 4, 24, 13, 8, 27, 10, 14, 5, 12, 4, 21, 5, 9, 5, 13, 7, 15, 39, 24, 16, 13, 24, 8, 12, 5, 4, 7, 13, 17, 11, 10, 6, 5, 17, 6, 16, 13, 5, 12, 4, 64, 40, 47, 54, 32, 23, 4, 53, 49, 32, 23, 4, 54, 8, 40, 47, 54, 32, 7, 23, 4, 69, 52, 43, 23, 4, 51, 10, 12, 6, 10, 15, 40, 5, 13, 6, 23, 4, 69, 52, 48, 5, 6, 26, 26, 26, 63, 4, 19, 8, 13, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 61, 9, 14, 5, 13, 12, 6, 7, 9, 14, 10, 9, 21, 4, 64, 48, 52, 61, 63, 4, 7, 9, 14, 4, 48, 7, 6, 16, 13, 7, 15, 4, 52, 7, 9, 21, 16, 7, 21, 5, 4, 53, 5, 9, 5, 13, 7, 6, 10, 8, 9, 4, 64, 48, 52, 53, 63, 4, 20, 10, 6, 11, 4, 8, 27, 5, 13, 4, 6, 11, 10, 13, 6, 22, 39, 6, 20, 8, 4, 24, 13, 5, 6, 13, 7, 10, 9, 5, 14, 4, 18, 8, 14, 5, 15, 12, 4, 10, 9, 4, 8, 9, 5, 4, 11, 16, 9, 14, 13, 5, 14, 4, 24, 15, 16, 12, 4, 15, 7, 9, 21, 16, 7, 21, 5, 12, 4, 7, 9, 14, 4, 14, 5, 5, 24, 4, 10, 9, 6, 5, 13, 8, 24, 5, 13, 7, 25, 10, 15, 10, 6, 22, 4, 25, 5, 6, 20, 5, 5, 9, 4, 58, 7, 37, 23, 4, 49, 22, 32, 8, 13, 17, 11, 4, 7, 9, 14, 4, 32, 5, 9, 12, 8, 13, 55, 15, 8, 20, 26, 2],
[4, 40, 47, 54, 32, 4, 10, 12, 4, 14, 5, 12, 10, 21, 9, 5, 14, 4, 6, 8, 4, 24, 13, 5, 39, 6, 13, 7, 10, 9, 4, 14, 5, 5, 24, 4, 25, 10, 14, 10, 13, 5, 17, 6, 10, 8, 9, 7, 15, 4, 13, 5, 24, 13, 5, 12, 5, 9, 6, 7, 6, 10, 8, 9, 12, 4, 19, 13, 8, 18, 4, 16, 9, 15, 7, 25, 5, 15, 5, 14, 4, 6, 5, 37, 6, 4, 25, 22, 4, 46, 8, 10, 9, 6, 15, 22, 4, 17, 8, 9, 14, 10, 6, 10, 8, 9, 10, 9, 21, 4, 8, 9, 4, 25, 8, 6, 11, 4, 15, 5, 19, 6, 4, 7, 9, 14, 4, 13, 10, 21, 11, 6, 4, 17, 8, 9, 6, 5, 37, 6, 4, 10, 9, 4, 7, 15, 15, 4, 15, 7, 22, 5, 13, 12, 26, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[4, 32, 11, 5, 4, 45, 16, 10, 17, 28, 4, 25, 13, 8, 20, 9, 4, 19, 8, 37, 4, 46, 16, 18, 24, 12, 4, 8, 27, 5, 13, 4, 6, 11, 5, 4, 15, 7, 57, 22, 4, 14, 8, 21, 26, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
],
'attention_mask': [
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
}
# fmt: on
self.tokenizer_integration_test_util(
......