Unverified commit 5be46dfc authored by Marco Carosi, committed by GitHub

[Whisper] Fix errors with MPS backend introduced by new code on word-level timestamps computation (#28288)

* Update modeling_whisper.py to support MPS backend

Fixed two issues with the MPS backend.

First, torch.std_mean is not implemented for MPS and is not scheduled for implementation, while the individual torch.std and torch.mean are.
Second, the MPS backend does not support float64, so it cannot cast from float32 to float64. Moving the double() call to after the matrix is on the CPU fixes the issue without changing the logic.
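As a sanity check (not part of the original commit), the two-call rewrite is numerically equivalent: with unbiased=False, torch.std_mean returns the population standard deviation together with the mean, which is exactly what separate std and mean calls compute. A minimal pure-Python sketch of that equivalence:

```python
import math

def std_mean(values):
    # Fused version: population statistics, matching unbiased=False
    # in torch.std_mean.
    mean_v = sum(values) / len(values)
    variance = sum((v - mean_v) ** 2 for v in values) / len(values)
    return math.sqrt(variance), mean_v

def std(values):
    # Standalone population std, as computed by torch.std(..., unbiased=False).
    mean_v = sum(values) / len(values)
    return math.sqrt(sum((v - mean_v) ** 2 for v in values) / len(values))

def mean(values):
    # Standalone mean, as computed by torch.mean.
    return sum(values) / len(values)

weights = [0.2, 0.5, 0.1, 0.9]
# The fused call and the two separate calls perform identical arithmetic,
# so the results are bit-for-bit equal.
assert std_mean(weights) == (std(weights), mean(weights))
```

This only demonstrates the math; on tensors the torch versions additionally reduce along dim=-2 with keepdim=True, which does not affect the equivalence.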

* Found another instruction in modeling_whisper.py not implemented by MPS

After a load test in which I transcribed a 2-hour audio file, I hit a code branch that the previous commit did not fix.
The fix is similar: torch.std_mean is replaced with separate torch.std and torch.mean calls.

* Update modeling_whisper.py to remove trailing whitespace

Removed trailing whitespace

* Update modeling_whisper.py to use is_torch_mps_available()

Using is_torch_mps_available() instead of catching the NotImplementedError exception
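The general pattern in this step is to feature-detect up front rather than run the op and catch NotImplementedError at the call site. A sketch of the idea, not the actual transformers code (is_mps_requested and choose_norm_path are made-up names standing in for is_torch_mps_available() and the normalization branch):

```python
def is_mps_requested(device):
    # Hypothetical stand-in for transformers.utils.is_torch_mps_available():
    # a cheap capability check that never raises.
    return str(device).startswith("mps")

def choose_norm_path(device):
    # Decide once, up front, which normalization path to take, instead of
    # calling torch.std_mean and catching NotImplementedError on MPS,
    # which would only fail after dispatching the op.
    if is_mps_requested(device):
        return "separate torch.std and torch.mean"
    return "fused torch.std_mean"

print(choose_norm_path("mps:0"))  # → separate torch.std and torch.mean
print(choose_norm_path("cpu"))    # → fused torch.std_mean
```

(The merged diff below ultimately drops the branch and uses the split calls unconditionally, since they are equivalent on every backend.)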

* Update modeling_whisper.py sorting the import block

Sorting the utils import block

* Update src/transformers/models/whisper/modeling_whisper.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_whisper.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/whisper/modeling_whisper.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 87ae2a46
src/transformers/models/whisper/modeling_whisper.py
@@ -2599,7 +2599,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         if num_frames is None or isinstance(num_frames, int):
             # Normalize and smoothen the weights.
-            std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
+            std = torch.std(weights, dim=-2, keepdim=True, unbiased=False)
+            mean = torch.mean(weights, dim=-2, keepdim=True)
             weights = (weights - mean) / std
             weights = _median_filter(weights, self.config.median_filter_width)
@@ -2608,11 +2609,12 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
         # Perform dynamic time warping on each element of the batch.
         for batch_idx in range(batch_size):
-            if num_frames is not None and isinstance(num_frames, (tuple, list)):
+            if num_frames is not None and isinstance(num_frames, (tuple, list, np.ndarray)):
                 matrix = weights[batch_idx, ..., : num_frames[batch_idx] // 2]
                 # Normalize and smoothen the weights.
-                std, mean = torch.std_mean(matrix, dim=-2, keepdim=True, unbiased=False)
+                std = torch.std(matrix, dim=-2, keepdim=True, unbiased=False)
+                mean = torch.mean(matrix, dim=-2, keepdim=True)
                 matrix = (matrix - mean) / std
                 matrix = _median_filter(matrix, self.config.median_filter_width)
@@ -2621,7 +2623,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
             else:
                 matrix = weights[batch_idx]
-            text_indices, time_indices = _dynamic_time_warping(-matrix.double().cpu().numpy())
+            text_indices, time_indices = _dynamic_time_warping(-matrix.cpu().double().numpy())
             jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
             jump_times = time_indices[jumps] * time_precision
             timestamps[batch_idx, 1:] = torch.tensor(jump_times)