Unverified commit b1160c0b authored by nilboy, committed by GitHub

Fix wav2vec2 ONNX model export with attention_mask error (#16004)

* Fix wav2vec2 ONNX model export with attention_mask error

* fix repository_consistency
parent d91da4c6
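
Every hunk below makes the same change in a different speech encoder (the second commit bullet suggests the wav2vec2 fix was propagated to the other models via the repository consistency check): instead of indexing the (batch, seq_len, hidden_size) hidden states with the 2-D boolean attention_mask, which is what triggered the error when exporting with an attention_mask input, the mask is first expanded to the full hidden-state shape and then used for the assignment. The two formulations zero out exactly the same positions; a minimal sketch of the equivalence (function names here are illustrative, not from the patch):

import torch

def zero_out_padding_old(hidden_states, attention_mask):
    # pre-patch behaviour: index a (batch, seq, hidden) tensor with a (batch, seq) bool mask
    hidden_states = hidden_states.clone()
    hidden_states[~attention_mask] = 0.0
    return hidden_states

def zero_out_padding_new(hidden_states, attention_mask):
    # patched behaviour: expand the mask to (batch, seq, hidden) before indexing
    expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
    hidden_states = hidden_states.clone()
    hidden_states[~expand_attention_mask] = 0
    return hidden_states

hidden_states = torch.randn(2, 5, 4)
attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)
assert torch.equal(zero_out_padding_old(hidden_states, attention_mask),
                   zero_out_padding_new(hidden_states, attention_mask))
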
@@ -574,7 +574,8 @@ class Data2VecAudioEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -660,7 +660,8 @@ class HubertEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -748,7 +749,8 @@ class HubertEncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -697,7 +697,8 @@ class UniSpeechEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -785,7 +786,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -711,7 +711,8 @@ class UniSpeechSatEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -799,7 +800,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -745,7 +745,8 @@ class Wav2Vec2Encoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

@@ -832,7 +833,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0

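Exporting one of these checkpoints with an attention_mask input is the scenario this commit is meant to unblock. A rough usage sketch; the checkpoint name, dummy input shapes, and opset version are placeholders chosen for illustration, not taken from the patch:

import torch
from transformers import Wav2Vec2Model

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")  # placeholder checkpoint
model.config.return_dict = False  # export a plain tuple instead of a ModelOutput
model.eval()

# dummy one-second waveform at 16 kHz and a matching attention mask
input_values = torch.randn(1, 16000)
attention_mask = torch.ones(1, 16000, dtype=torch.long)

torch.onnx.export(
    model,
    (input_values, attention_mask),
    "wav2vec2.onnx",
    input_names=["input_values", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_values": {0: "batch", 1: "sequence"},
        "attention_mask": {0: "batch", 1: "sequence"},
    },
    opset_version=11,  # assumed sufficient for the expanded-mask indexing; adjust as needed
)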