Unverified commit 994b7a4e authored by Arthur, committed by GitHub
Browse files

update attention mask handling (#19385)

* update feature extractor params

* update attention mask handling
parent a26d71d6
......@@ -218,6 +218,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
return_attention_mask: Optional[bool] = None,
padding: Optional[str] = "max_length",
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
**kwargs
) -> BatchFeature:
"""
......@@ -255,11 +256,25 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
- `'np'`: Return Numpy `np.ndarray` objects.
sampling_rate (`int`, *optional*):
The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
`sampling_rate` at the forward call to prevent silent errors.
`sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
pipeline.
padding_value (`float`, defaults to 0.0):
The value that is used to fill the padding values / vectors.
"""
if sampling_rate is not None:
if sampling_rate != self.sampling_rate:
raise ValueError(
f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
f" {self.sampling_rate} and not {sampling_rate}."
)
else:
logger.warning(
"It is strongly recommended to pass the `sampling_rate` argument to this function. "
"Failing to do so can result in silent errors that might be hard to debug."
)
is_batched = bool(
isinstance(raw_speech, (list, tuple))
and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Union
......@@ -259,9 +260,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
# Currently chunking is not possible at this level for `seq2seq` so
# it's ok.
align_to = self.model.config.inputs_to_logits_ratio
chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to)) * align_to
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to)) * align_to
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to)) * align_to
chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
if self.type not in {"ctc", "ctc_with_lm"}:
raise ValueError(
......@@ -304,12 +305,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
)
accepts_attention_mask = "attention_mask" in set(inspect.signature(encoder.forward).parameters.keys())
if accepts_attention_mask:
attention_mask = model_inputs.pop("attention_mask", None)
tokens = self.model.generate(
encoder_outputs=encoder(inputs, attention_mask=attention_mask),
attention_mask=attention_mask,
)
else:
tokens = self.model.generate(inputs)
out = {"tokens": tokens}
else:
stride = model_inputs.pop("stride", None)
input_values = model_inputs.pop("input_values")
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment