Unverified Commit f16d29b3 authored by Steven Anton, committed by GitHub

Adapt PerceiverIO Multimodal class to work with arbitrary modalities (#20054)



* * Properly register parameters in PerceiverMultimodalPreprocessor
* Adapt PerceiverTextPreprocessor to work with PerceiverMultimodalPreprocessor
* Change a few type hints

* Fix formatting; incorrect return type

* Return embeddings_wo_pos

---------
Co-authored-by: Steven Anton <antonstv@amazon.com>
parent c236a621
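The diff below does two things at once: every modality preprocessor now shares the same forward signature and returns the same (embeddings, modality_sizes, inputs_without_pos) triple, and PerceiverMultimodalPreprocessor registers its per-modality preprocessors as real submodules. A minimal sketch of the intended composition after this patch; the single text modality and the default config are illustrative only, not taken from the diff:

import torch
from transformers import PerceiverConfig
from transformers.models.perceiver.modeling_perceiver import (
    PerceiverMultimodalPreprocessor,
    PerceiverTextPreprocessor,
)

config = PerceiverConfig()  # default config, purely for illustration
text_prep = PerceiverTextPreprocessor(config)

# The multimodal wrapper can now take any modality -> preprocessor mapping,
# including the text preprocessor, because all preprocessors expose the same
# forward(inputs, pos=None, network_input_is_1d=True) contract.
multi_prep = PerceiverMultimodalPreprocessor(modalities={"text": text_prep})

# Since the wrapper stores its children in an nn.ModuleDict, the text
# embedding weights are visible to the parent's parameters() and state_dict().
assert any(name.startswith("modalities.text") for name, _ in multi_prep.named_parameters())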
@@ -877,7 +877,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
         # If no attention mask is provided, make them all ones
         if attention_mask is None:
-            attention_mask = torch.ones(((batch_size, seq_length)), device=device)
+            attention_mask = torch.ones((batch_size, seq_length), device=device)
         # Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         extended_attention_mask = self.invert_attention_mask(attention_mask)
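The change in this hunk is purely cosmetic (a doubled pair of parentheses around the shape tuple). For context, the all-ones mask is then converted by the inherited invert_attention_mask helper into an additive mask that broadcasts over heads and query positions; roughly the following shape logic, sketched here rather than the exact library code:

import torch

batch_size, seq_length = 2, 5
attention_mask = torch.ones((batch_size, seq_length))          # 1 = attend, 0 = padding

extended = attention_mask[:, None, None, :]                     # [2, 1, 1, 5], broadcastable over heads and queries
extended = (1.0 - extended) * torch.finfo(torch.float32).min    # 0 where visible, very negative where masked
print(extended.shape)  # torch.Size([2, 1, 1, 5])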
@@ -911,7 +911,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
                     "label": 1,
                 }
             else:
-                output_modality_sizes = None
+                output_modality_sizes = modality_sizes
             decoder_query = self.decoder.decoder_query(
                 inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
             )
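Returning modality_sizes instead of None in the no-subsampling branch matters downstream: the multimodal postprocessor uses these sizes to slice the decoder output back into per-modality chunks, and with None that split is impossible. A hedged sketch of that restructuring step; the helper name is illustrative, not the library's internal function:

import torch

def split_by_modality(output: torch.Tensor, modality_sizes: dict) -> dict:
    # Slice [batch, total_seq, channels] back into per-modality tensors.
    result, index = {}, 0
    for name in sorted(modality_sizes):
        size = modality_sizes[name]
        result[name] = output[:, index : index + size]
        index += size
    return result

decoder_output = torch.randn(1, 10, 8)
chunks = split_by_modality(decoder_output, {"audio": 6, "image": 3, "label": 1})
print({name: t.shape for name, t in chunks.items()})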
@@ -952,7 +952,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
 @add_start_docstrings("""Example use of Perceiver for masked language modeling.""", PERCEIVER_START_DOCSTRING)
 class PerceiverForMaskedLM(PerceiverPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PerceiverConfig):
         super().__init__(config)
         text_preprocessor = PerceiverTextPreprocessor(config)
@@ -1777,7 +1777,7 @@ Note that, by masking the classification label during evaluation (i.e. simply pr
     PERCEIVER_START_DOCSTRING,
 )
 class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
-    def __init__(self, config):
+    def __init__(self, config: PerceiverConfig):
         super().__init__(config)
         n_audio_samples = config.num_frames * config.audio_samples_per_frame
@@ -2123,7 +2123,7 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
         self.output_num_channels = output_num_channels
         # If `none`, the decoder will not construct any position encodings.
-        # You should construct your own when quering the decoder.
+        # You should construct your own when querying the decoder.
         self.output_position_encodings = None
         self.position_encoding_type = position_encoding_type
         self.position_encoding_kwargs = position_encoding_kwargs
@@ -2849,14 +2849,14 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
     def num_channels(self) -> int:
         return self.config.d_model
-    def forward(self, inputs: torch.LongTensor) -> torch.FloatTensor:
-        embeddings = self.embeddings(inputs)
+    def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
+        embeddings_without_pos = self.embeddings(inputs)
         seq_length = inputs.shape[1]
         position_ids = torch.arange(0, seq_length, device=inputs.device)
-        embeddings = embeddings + self.position_embeddings(position_ids)
-        return embeddings, None, None
+        embeddings = embeddings_without_pos + self.position_embeddings(position_ids)
+        return embeddings, None, embeddings_without_pos
 class PerceiverEmbeddingDecoder(nn.Module):
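With this change PerceiverTextPreprocessor accepts the same (inputs, pos, network_input_is_1d) arguments as the image and audio preprocessors (the extra two are simply ignored for text) and returns the same three-element tuple, with the token embeddings before position encoding in the last slot. A small usage sketch of the new contract, assuming a default PerceiverConfig:

import torch
from transformers import PerceiverConfig
from transformers.models.perceiver.modeling_perceiver import PerceiverTextPreprocessor

prep = PerceiverTextPreprocessor(PerceiverConfig())

input_ids = torch.tensor([[5, 17, 42]])
embeddings, modality_sizes, inputs_without_pos = prep(input_ids)

print(embeddings.shape)          # [1, 3, d_model]: token + position embeddings
print(modality_sizes)            # None for a single-modality preprocessor
print(inputs_without_pos.shape)  # [1, 3, d_model]: token embeddings only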
@@ -2889,7 +2889,7 @@ class PerceiverMultimodalPostprocessor(nn.Module):
     postprocessor.
     Args:
-        modalities (`Dict[str, PostprocessorType]`):
+        modalities (`Mapping[str, PostprocessorType]`):
             Dictionary mapping modality name to postprocessor class for that modality.
         input_is_dict (`bool`, *optional*, defaults to `False`):
             If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If
@@ -3345,7 +3345,7 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
     of channels.
     Args:
-        modalities (`Dict[str, PreprocessorType]`):
+        modalities (`Mapping[str, PreprocessorType]`):
             Dict mapping modality name to preprocessor.
         mask_probs (`Dict[str, float]`):
             Dict mapping modality name to masking probability of that modality.
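The Dict -> Mapping loosening in these two docstrings (and the matching type hints) presumably signals that the wrappers only read the mapping and never mutate it, so any dict-like object is acceptable. A tiny illustration, using standard-library types only:

from collections import OrderedDict
from typing import Mapping

def describe(modalities: Mapping[str, object]) -> None:
    # Read-only iteration is all the multimodal wrappers need from `modalities`.
    for name, preprocessor in modalities.items():
        print(name, type(preprocessor).__name__)

describe({"image": object(), "audio": object()})  # a plain dict
describe(OrderedDict(text=object()))              # any other Mapping works too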
@@ -3361,7 +3361,7 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
         min_padding_size: int = 2,
     ):
         super().__init__()
-        self.modalities = modalities
+        self.modalities = nn.ModuleDict(modalities)
         self.min_padding_size = min_padding_size
         self.mask_probs = mask_probs if mask_probs is not None else dict()
         self.padding = nn.ParameterDict(
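Wrapping the modality preprocessors in nn.ModuleDict is the "properly register parameters" part of the commit message: with a plain dict attribute, the child modules are invisible to the parent nn.Module, so their weights are skipped by parameters(), state_dict(), and .to(). A minimal, self-contained illustration with toy modules (not the Perceiver classes):

import torch
from torch import nn

class PlainDictHolder(nn.Module):
    def __init__(self, modalities: dict):
        super().__init__()
        self.modalities = modalities                 # plain dict: children not registered

class ModuleDictHolder(nn.Module):
    def __init__(self, modalities: dict):
        super().__init__()
        self.modalities = nn.ModuleDict(modalities)  # children registered as submodules

mods = {"text": nn.Embedding(10, 4), "image": nn.Linear(3, 4)}
print(len(list(PlainDictHolder(dict(mods)).parameters())))   # 0 -- optimizer and checkpoint miss them
print(len(list(ModuleDictHolder(dict(mods)).parameters())))  # 3 -- embedding weight + linear weight and bias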