Unverified Commit aefc59f0 authored by AllenDou's avatar AllenDou Committed by GitHub
Browse files

FunASR model bugfix (#36633)


Signed-off-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
Co-authored-by: default avatarzixiao <shunli.dsl@alibaba-inc.com>
parent d88f28da
...@@ -573,6 +573,8 @@ class Transformer(nn.Module): ...@@ -573,6 +573,8 @@ class Transformer(nn.Module):
) )
def forward(self, hidden_states: torch.Tensor, ilens: int = 0): def forward(self, hidden_states: torch.Tensor, ilens: int = 0):
max_len = max(ilens)
hidden_states = hidden_states[:, :max_len, :]
batch_size, seq_len, dim = hidden_states.size() batch_size, seq_len, dim = hidden_states.size()
chunk_num = (seq_len - 1) // self.k + 1 chunk_num = (seq_len - 1) // self.k + 1
pad_num = chunk_num * self.k - seq_len pad_num = chunk_num * self.k - seq_len
......
...@@ -268,6 +268,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor): ...@@ -268,6 +268,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
n_fft=400, n_fft=400,
padding_value=0.0, padding_value=0.0,
dither=0.0, dither=0.0,
max_length=1000,
return_attention_mask=False, return_attention_mask=False,
**kwargs, **kwargs,
): ):
...@@ -279,6 +280,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor): ...@@ -279,6 +280,7 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
**kwargs, **kwargs,
) )
self.frontend_conf = kwargs.get("frontend_conf", {}) self.frontend_conf = kwargs.get("frontend_conf", {})
self.max_length = max_length
self.n_fft = n_fft self.n_fft = n_fft
self.hop_length = hop_length self.hop_length = hop_length
self.chunk_length = chunk_length self.chunk_length = chunk_length
...@@ -329,64 +331,41 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor): ...@@ -329,64 +331,41 @@ class FunASRFeatureExtractor(SequenceFeatureExtractor):
return_token_timestamps: bool | None = None, return_token_timestamps: bool | None = None,
**kwargs, **kwargs,
) -> BatchFeature: ) -> BatchFeature:
is_batched = isinstance(raw_speech, (list, tuple)) and ( frontend = WavFrontend(**self.frontend_conf, dither=self.dither)
isinstance(raw_speech[0], (np.ndarray, tuple, list))
)
if is_batched:
raw_speech = [
np.asarray([speech], dtype=np.float32).T for speech in raw_speech
]
elif not is_batched and not isinstance(raw_speech, np.ndarray):
raw_speech = np.asarray(raw_speech, dtype=np.float32)
elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(
np.float64
):
raw_speech = raw_speech.astype(np.float32)
if not is_batched:
raw_speech = [np.asarray([raw_speech]).T]
batched_speech = BatchFeature({"input_features": raw_speech})
padded_inputs = self.pad( feats = []
batched_speech, speech_lengths = []
fake_token_lengths = []
for speech in raw_speech:
feature, length = self.extract_fbank(
speech,
data_type=kwargs.get("data_type", "sound"),
frontend=frontend,
is_final=True,
)
feats.append(feature)
speech_lengths.append(length)
olens = 1 + (length - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
fake_token_len = (olens - 1) // 2 + 1
fake_token_len = torch.clamp(fake_token_len, min=1)
fake_token_lengths.append(fake_token_len)
feats = torch.concat(feats, dim=0)
batched_speech = self.pad(
BatchFeature({"input_features": feats}),
padding=padding, padding=padding,
max_length=max_length if max_length else self.n_samples, max_length=max_length if max_length else self.max_length,
truncation=truncation, truncation=truncation,
pad_to_multiple_of=pad_to_multiple_of, pad_to_multiple_of=pad_to_multiple_of,
return_attention_mask=return_attention_mask or do_normalize, return_attention_mask=return_attention_mask or do_normalize,
) )
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
frontend = WavFrontend(**self.frontend_conf, dither=self.dither)
input_features, speech_lengths = self.extract_fbank(
input_features[0],
data_type=kwargs.get("data_type", "sound"),
frontend=frontend,
is_final=True,
)
olens = 1 + (speech_lengths - 3 + 2 * 1) // 2
olens = 1 + (olens - 3 + 2 * 1) // 2
fake_token_lengths = (olens - 1) // 2 + 1
if isinstance(input_features[0], list):
padded_inputs["input_features"] = [
np.asarray(feature, dtype=np.float32) for feature in input_features
]
else:
padded_inputs["input_features"] = input_features
if return_tensors is not None: if return_tensors is not None:
padded_inputs = padded_inputs.convert_to_tensors(return_tensors) batched_speech = batched_speech.convert_to_tensors(return_tensors)
fake_token_lengths = torch.clamp(fake_token_lengths, min=1)
padded_inputs["speech_lengths"] = speech_lengths
padded_inputs["fake_token_lengths"] = fake_token_lengths
return padded_inputs batched_speech["speech_lengths"] = torch.tensor(speech_lengths)
batched_speech["fake_token_lengths"] = torch.concat(fake_token_lengths)
return batched_speech
class FunASRProcessor(ProcessorMixin): class FunASRProcessor(ProcessorMixin):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment