Unverified commit f16d29b3, authored by Steven Anton, committed by GitHub
Browse files

Adapt PerceiverIO Multimodal class to work with arbitrary modalities (#20054)



* * Properly register parameters in PerceiverMultimodalPreprocessor
* Adapt PerceiverTextPreprocessor to work with PerceiverMultimodalPreprocessor
* Change a few type hints

* Fix formatting; incorrect return type

* Return embeddings_wo_pos

---------
Co-authored-by: Steven Anton <antonstv@amazon.com>
parent c236a621
......@@ -877,7 +877,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
# If no attention mask is provided, make them all ones
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length)), device=device)
attention_mask = torch.ones((batch_size, seq_length), device=device)
# Make the attention mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
extended_attention_mask = self.invert_attention_mask(attention_mask)
......@@ -911,7 +911,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
"label": 1,
}
else:
output_modality_sizes = None
output_modality_sizes = modality_sizes
decoder_query = self.decoder.decoder_query(
inputs, modality_sizes, inputs_without_pos, subsampled_points=subsampled_output_points
)
......@@ -952,7 +952,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
@add_start_docstrings("""Example use of Perceiver for masked language modeling.""", PERCEIVER_START_DOCSTRING)
class PerceiverForMaskedLM(PerceiverPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PerceiverConfig):
super().__init__(config)
text_preprocessor = PerceiverTextPreprocessor(config)
......@@ -1777,7 +1777,7 @@ Note that, by masking the classification label during evaluation (i.e. simply pr
PERCEIVER_START_DOCSTRING,
)
class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
def __init__(self, config):
def __init__(self, config: PerceiverConfig):
super().__init__(config)
n_audio_samples = config.num_frames * config.audio_samples_per_frame
......@@ -2123,7 +2123,7 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
self.output_num_channels = output_num_channels
# If `none`, the decoder will not construct any position encodings.
# You should construct your own when quering the decoder.
# You should construct your own when querying the decoder.
self.output_position_encodings = None
self.position_encoding_type = position_encoding_type
self.position_encoding_kwargs = position_encoding_kwargs
......@@ -2849,14 +2849,14 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
def num_channels(self) -> int:
return self.config.d_model
def forward(self, inputs: torch.LongTensor) -> torch.FloatTensor:
embeddings = self.embeddings(inputs)
def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
embeddings_without_pos = self.embeddings(inputs)
seq_length = inputs.shape[1]
position_ids = torch.arange(0, seq_length, device=inputs.device)
embeddings = embeddings + self.position_embeddings(position_ids)
embeddings = embeddings_without_pos + self.position_embeddings(position_ids)
return embeddings, None, None
return embeddings, None, embeddings_without_pos
class PerceiverEmbeddingDecoder(nn.Module):
......@@ -2889,7 +2889,7 @@ class PerceiverMultimodalPostprocessor(nn.Module):
postprocessor.
Args:
modalities (`Dict[str, PostprocessorType]`):
modalities (`Mapping[str, PostprocessorType]`):
Dictionary mapping modality name to postprocessor class for that modality.
input_is_dict (`bool`, *optional*, defaults to `False`):
If True, input is assumed to be dictionary structured, and outputs keep the same dictionary shape. If
......@@ -3345,7 +3345,7 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
of channels.
Args:
modalities (`Dict[str, PreprocessorType]`):
modalities (`Mapping[str, PreprocessorType]`):
Dict mapping modality name to preprocessor.
mask_probs (`Dict[str, float]`):
Dict mapping modality name to masking probability of that modality.
......@@ -3361,7 +3361,7 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
min_padding_size: int = 2,
):
super().__init__()
self.modalities = modalities
self.modalities = nn.ModuleDict(modalities)
self.min_padding_size = min_padding_size
self.mask_probs = mask_probs if mask_probs is not None else dict()
self.padding = nn.ParameterDict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.