Unverified commit 0b0acc75, authored by Harry Mellor, committed by GitHub
Browse files

Remove `head_mask` from Ultravox and Swin (#30764)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent af506fd7
...@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module): ...@@ -102,7 +102,6 @@ class SwinSelfAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None, attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
) -> tuple[torch.Tensor, ...]: ) -> tuple[torch.Tensor, ...]:
batch_size, dim, num_channels = hidden_states.shape batch_size, dim, num_channels = hidden_states.shape
...@@ -201,12 +200,9 @@ class SwinAttention(nn.Module): ...@@ -201,12 +200,9 @@ class SwinAttention(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: torch.FloatTensor | None = None, attention_mask: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
self_outputs = self.self( self_outputs = self.self(hidden_states, attention_mask, output_attentions)
hidden_states, attention_mask, head_mask, output_attentions
)
attention_output = self.output(self_outputs[0], hidden_states) attention_output = self.output(self_outputs[0], hidden_states)
outputs = (attention_output,) + self_outputs[1:] outputs = (attention_output,) + self_outputs[1:]
return outputs return outputs
...@@ -339,18 +335,14 @@ class SwinStage(nn.Module): ...@@ -339,18 +335,14 @@ class SwinStage(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_dimensions: tuple[int, int], input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
always_partition: bool | None = False, always_partition: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
height, width = input_dimensions height, width = input_dimensions
for i, layer_module in enumerate(self.blocks): for i, layer_module in enumerate(self.blocks):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
input_dimensions, input_dimensions,
layer_head_mask,
output_attentions, output_attentions,
always_partition, always_partition,
) )
...@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module): ...@@ -425,17 +417,13 @@ class SwinEncoder(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_dimensions: tuple[int, int], input_dimensions: tuple[int, int],
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = False, output_attentions: bool | None = False,
always_partition: bool | None = False, always_partition: bool | None = False,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
for i, layer_module in enumerate(self.layers): for i, layer_module in enumerate(self.layers):
layer_head_mask = head_mask[i] if head_mask is not None else None
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
input_dimensions, input_dimensions,
layer_head_mask,
output_attentions, output_attentions,
always_partition, always_partition,
) )
...@@ -473,7 +461,6 @@ class SwinModel(nn.Module): ...@@ -473,7 +461,6 @@ class SwinModel(nn.Module):
def forward( def forward(
self, self,
pixel_values: torch.FloatTensor | None = None, pixel_values: torch.FloatTensor | None = None,
head_mask: torch.FloatTensor | None = None,
output_attentions: bool | None = None, output_attentions: bool | None = None,
) -> tuple[torch.Tensor]: ) -> tuple[torch.Tensor]:
embedding_output, input_dimensions = self.embeddings(pixel_values) embedding_output, input_dimensions = self.embeddings(pixel_values)
...@@ -481,7 +468,6 @@ class SwinModel(nn.Module): ...@@ -481,7 +468,6 @@ class SwinModel(nn.Module):
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
embedding_output, embedding_output,
input_dimensions, input_dimensions,
head_mask=head_mask,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
"""PyTorch Ultravox model.""" """PyTorch Ultravox model."""
import copy import copy
import inspect
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from types import SimpleNamespace from types import SimpleNamespace
from typing import Annotated, Any, Literal, TypeAlias from typing import Annotated, Any, Literal, TypeAlias
...@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin): ...@@ -380,11 +381,17 @@ class UltravoxTransformerProjector(nn.Module, ModuleUtilsMixin):
) )
hidden_states = hidden_states + positions hidden_states = hidden_states + positions
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for layer in self.layers: for layer in self.layers:
layer_outputs = layer( layer_outputs = layer(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=extended_attention_mask,
layer_head_mask=None, **kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
...@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder): ...@@ -479,11 +486,17 @@ class ModifiedWhisperEncoder(WhisperEncoder):
attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states) attention_mask = self.get_attention_mask_by_audio_len(audio_lens, hidden_states)
# Backward compatibility for Transformers v4 where layer_head_mask
# was a required argument for WhisperEncoderLayer.forward
kwargs = {}
if "layer_head_mask" in inspect.signature(self.layers[0].forward).parameters:
kwargs["layer_head_mask"] = None
for encoder_layer in self.layers: for encoder_layer in self.layers:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask=None, **kwargs,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment