Unverified Commit b1160c0b authored by nilboy, committed by GitHub

Fix wav2vec2 export onnx model with attention_mask error (#16004)

* Fix wav2vec2 export onnx model with attention_mask error

* fix repository_consistency
parent d91da4c6
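
Every touched encoder zeroes out padded positions with a boolean-mask assignment. The old form indexes a 3D hidden_states tensor with a 2D mask, which is valid in eager PyTorch but is what fails during ONNX export with an attention_mask (the error this commit fixes); each site now expands the mask to the full hidden-state shape before indexing. A minimal sketch, not part of the commit and using made-up toy shapes, showing that the two forms agree in eager mode:

import torch

batch, seq_len, hidden = 2, 5, 4                      # toy sizes for illustration
hidden_states = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 0, 0, 0]]).bool()

# old form: 2D boolean mask indexing a 3D tensor
old = hidden_states.clone()
old[~attention_mask] = 0.0

# new form: expand the mask to hidden_states' full shape first
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
new = hidden_states.clone()
new[~expand_attention_mask] = 0

assert torch.equal(old, new)  # identical results in eager mode
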
src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -574,7 +574,8 @@ class Data2VecAudioEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
src/transformers/models/hubert/modeling_hubert.py
@@ -660,7 +660,8 @@ class HubertEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -748,7 +749,8 @@ class HubertEncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
src/transformers/models/unispeech/modeling_unispeech.py
@@ -697,7 +697,8 @@ class UniSpeechEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -785,7 +786,8 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -711,7 +711,8 @@ class UniSpeechSatEncoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -799,7 +800,8 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -745,7 +745,8 @@ class Wav2Vec2Encoder(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens output 0
-            hidden_states[~attention_mask] = 0.0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
@@ -832,7 +833,8 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
         if attention_mask is not None:
             # make sure padded tokens are not attended to
-            hidden_states[~attention_mask] = 0
+            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+            hidden_states[~expand_attention_mask] = 0
             # extend attention_mask
             attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
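
With the expanded mask in place, exporting one of these models with an attention_mask input should no longer error. A hedged usage sketch, not from the commit itself: the checkpoint name, opset version, file name, and axis labels below are illustrative assumptions.

import torch
from transformers import Wav2Vec2Model

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")
model.eval()

# dummy inputs: one second of 16 kHz audio, all samples marked valid
input_values = torch.randn(1, 16000)
attention_mask = torch.ones(1, 16000, dtype=torch.long)

torch.onnx.export(
    model,
    (input_values, attention_mask),   # attention_mask passed positionally
    "wav2vec2.onnx",
    input_names=["input_values", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={
        "input_values": {0: "batch", 1: "samples"},
        "attention_mask": {0: "batch", 1: "samples"},
    },
    opset_version=13,
)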