Unverified Commit 0b0acc75 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Remove `head_mask` from Ultravox and Swin (#30764)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent af506fd7
......@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape
......@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
self,
hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
) -> tuple[torch.Tensor]:
self_outputs = self.self(
hidden_states, attention_mask, head_mask, output_attentions
)
self_outputs = self.self(hidden_states, attention_mask, output_attentions)
attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:]
return outputs
......@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
height, width = input_dimensions
for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
......@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
self,
hidden_states: torch.Tensor,
input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False,
always_partition: bool | None = False,
) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module(
hidden_states,
input_dimensions,
layer_head_mask,
output_attentions,
always_partition,
)
......@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
def forward(
self,
pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None,
) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values)
......@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
encoder_outputs = self.encoder(
embedding_output,
input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions,
)
......
......@@ -5,6 +5,7 @@
"""PyTorch Ultravox model."""
import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias
......@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
)
hidden_states = hidden_states + positions
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for layer in self.layers:
layer_outputs = layer(
hidden_states,
attention_mask=extended_attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]
......@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for encoder_layer in self.layers:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask=None,
**kwargs,
)
hidden_states = layer_outputs[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment