"tests/vscode:/vscode.git/clone" did not exist on "d1e3f489e909d351ea6abced05b0cdaf3a41be7a"
Unverified Commit 994b7a4e authored by Arthur, committed by GitHub

update attention mask handling (#19385)

* update feature extractor params

* update attention mask handling
parent a26d71d6
@@ -218,6 +218,7 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
         return_attention_mask: Optional[bool] = None,
         padding: Optional[str] = "max_length",
         max_length: Optional[int] = None,
+        sampling_rate: Optional[int] = None,
         **kwargs
     ) -> BatchFeature:
         """
@@ -255,11 +256,25 @@ class WhisperFeatureExtractor(SequenceFeatureExtractor):
                 - `'np'`: Return Numpy `np.ndarray` objects.
             sampling_rate (`int`, *optional*):
                 The sampling rate at which the `raw_speech` input was sampled. It is strongly recommended to pass
-                `sampling_rate` at the forward call to prevent silent errors.
+                `sampling_rate` at the forward call to prevent silent errors and allow automatic speech recognition
+                pipeline.
             padding_value (`float`, defaults to 0.0):
                 The value that is used to fill the padding values / vectors.
         """
+        if sampling_rate is not None:
+            if sampling_rate != self.sampling_rate:
+                raise ValueError(
+                    f"The model corresponding to this feature extractor: {self} was trained using a sampling rate of"
+                    f" {self.sampling_rate}. Please make sure that the provided `raw_speech` input was sampled with"
+                    f" {self.sampling_rate} and not {sampling_rate}."
+                )
+        else:
+            logger.warning(
+                "It is strongly recommended to pass the `sampling_rate` argument to this function. "
+                "Failing to do so can result in silent errors that might be hard to debug."
+            )
+
         is_batched = bool(
             isinstance(raw_speech, (list, tuple))
             and (isinstance(raw_speech[0], np.ndarray) or isinstance(raw_speech[0], (tuple, list)))
...
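To see the new check from the caller's side, here is a minimal sketch (assuming the `openai/whisper-tiny` checkpoint and 16 kHz input, neither of which this diff pins down):

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")

# One second of silence at the rate the extractor was trained on (16 kHz for Whisper).
raw_speech = np.zeros(16000, dtype=np.float32)

# Passing `sampling_rate` lets the extractor verify it matches `self.sampling_rate`.
features = feature_extractor(raw_speech, sampling_rate=16000, return_tensors="np")
print(features["input_features"].shape)

# A mismatched rate now fails loudly instead of silently producing bad features.
try:
    feature_extractor(raw_speech, sampling_rate=8000)
except ValueError as err:
    print(err)

# Omitting `sampling_rate` only logs a warning, as the diff above shows.
feature_extractor(raw_speech)
```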
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 from collections import defaultdict
 from typing import TYPE_CHECKING, Dict, Optional, Union
@@ -259,9 +260,9 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
             # Currently chunking is not possible at this level for `seq2seq` so
             # it's ok.
             align_to = self.model.config.inputs_to_logits_ratio
-            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to)) * align_to
-            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to)) * align_to
-            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to)) * align_to
+            chunk_len = int(round(chunk_length_s * self.feature_extractor.sampling_rate / align_to) * align_to)
+            stride_left = int(round(stride_length_s[0] * self.feature_extractor.sampling_rate / align_to) * align_to)
+            stride_right = int(round(stride_length_s[1] * self.feature_extractor.sampling_rate / align_to) * align_to)
             if self.type not in {"ctc", "ctc_with_lm"}:
                 raise ValueError(
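The parenthesis move above is easy to miss: `inputs_to_logits_ratio` can be a float, so the old form `int(round(...)) * align_to` could leave `chunk_len` as a float, which fails later when used to slice arrays. A minimal sketch with illustrative numbers (the 320.0 ratio, 16 kHz rate, and 20.1 s chunk are assumptions, not values from the diff):

```python
# Why the parenthesis move matters: with a float `align_to`, the old form
# returns a float, while the new form rounds once and returns an int.
align_to = 320.0  # an illustrative float inputs_to_logits_ratio
sampling_rate = 16000
chunk_length_s = 20.1

old = int(round(chunk_length_s * sampling_rate / align_to)) * align_to
new = int(round(chunk_length_s * sampling_rate / align_to) * align_to)

print(old, type(old))  # 321600.0 <class 'float'>
print(new, type(new))  # 321600 <class 'int'>
```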
@@ -304,12 +305,18 @@ class AutomaticSpeechRecognitionPipeline(ChunkPipeline):
                     f"`input_features` or `input_values` key, but only has {model_inputs.keys()}"
                 )
-            attention_mask = model_inputs.pop("attention_mask", None)
-            tokens = self.model.generate(
-                encoder_outputs=encoder(inputs, attention_mask=attention_mask),
-                attention_mask=attention_mask,
-            )
+            accepts_attention_mask = "attention_mask" in set(inspect.signature(encoder.forward).parameters.keys())
+            if accepts_attention_mask:
+                attention_mask = model_inputs.pop("attention_mask", None)
+                tokens = self.model.generate(
+                    encoder_outputs=encoder(inputs, attention_mask=attention_mask),
+                    attention_mask=attention_mask,
+                )
+            else:
+                tokens = self.model.generate(inputs)
+
             out = {"tokens": tokens}
         else:
             stride = model_inputs.pop("stride", None)
             input_values = model_inputs.pop("input_values")
...
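The `inspect.signature` probe is the general Python pattern for branching on whether a callable accepts a keyword, rather than catching a `TypeError` at call time; it lets the seq2seq path work for encoders whose forward does not take `attention_mask`. A self-contained sketch of the same check on two toy functions (both names hypothetical):

```python
import inspect

def forward_with_mask(x, attention_mask=None):
    ...

def forward_without_mask(x):
    ...

# The same probe the diff adds: inspect the callable's signature up front.
for fn in (forward_with_mask, forward_without_mask):
    accepts_attention_mask = "attention_mask" in set(inspect.signature(fn).parameters.keys())
    print(fn.__name__, accepts_attention_mask)
# forward_with_mask True
# forward_without_mask False
```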