Unverified Commit e926ea2b authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Improve perceiver (#14750)

* First draft

* Improve docstring + clean up tests

* Remove unused code

* Add check in case one doesn't provide a preprocessor
parent 971e3666
......@@ -42,7 +42,8 @@ class PerceiverConfig(PretrainedConfig):
d_latents (:obj:`int`, `optional`, defaults to 1280):
Dimension of the latent embeddings.
d_model (:obj:`int`, `optional`, defaults to 768):
Dimension of the inputs.
Dimension of the inputs. Should only be provided in case [`PerceiverTextPreprocessor`] is used or no
preprocessor is provided.
num_blocks (:obj:`int`, `optional`, defaults to 1):
Number of blocks in the Transformer encoder.
num_self_attends_per_block (:obj:`int`, `optional`, defaults to 26):
......
......@@ -499,7 +499,7 @@ class PerceiverLayer(nn.Module):
class PerceiverEncoder(nn.Module):
"""The Perceiver Encoder: a scalable, fully attentional encoder."""
def __init__(self, config):
def __init__(self, config, kv_dim=None):
super().__init__()
self.config = config
......@@ -523,7 +523,7 @@ class PerceiverEncoder(nn.Module):
v_channels=config.v_channels,
num_heads=config.num_cross_attention_heads,
q_dim=config.d_latents,
kv_dim=config.d_model,
kv_dim=kv_dim,
widening_factor=config.cross_attention_widening_factor,
use_query_residual=config.use_query_residual,
)
......@@ -734,7 +734,9 @@ class PerceiverModel(PerceiverPreTrainedModel):
self.input_preprocessor = input_preprocessor
self.output_postprocessor = output_postprocessor
self.embeddings = PerceiverEmbeddings(config)
self.encoder = PerceiverEncoder(config)
self.encoder = PerceiverEncoder(
config, kv_dim=input_preprocessor.num_channels if input_preprocessor is not None else config.d_model
)
self.decoder = decoder
# Initialize weights and apply final processing
......@@ -782,16 +784,13 @@ class PerceiverModel(PerceiverPreTrainedModel):
else:
modality_sizes = None
inputs_without_pos = None
if inputs.size()[-1] != self.config.d_model:
raise ValueError(
f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
"Make sure to set config.d_model appropriately."
)
if inputs.size()[-1] != self.config.d_model:
raise ValueError(
f"Last dimension of the inputs: {inputs.size()[-1]} doesn't correspond to config.d_model: {self.config.d_model}. "
"Please update config.d_model appropriately."
)
else:
input_shape = inputs.size()
batch_size, seq_length, _ = input_shape
batch_size, seq_length, _ = inputs.size()
device = inputs.device
# If no attention mask is provided, make them all ones
......@@ -874,20 +873,22 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
def __init__(self, config):
super().__init__(config)
text_preprocessor = PerceiverTextPreprocessor(config)
trainable_position_encoding_kwargs_decoder = dict(
num_channels=config.d_model, index_dims=config.max_position_embeddings
num_channels=text_preprocessor.num_channels, index_dims=config.max_position_embeddings
)
self.perceiver = PerceiverModel(
config,
input_preprocessor=PerceiverTextPreprocessor(config),
input_preprocessor=text_preprocessor,
decoder=PerceiverBasicDecoder(
config,
output_num_channels=config.d_latents,
output_index_dims=config.max_position_embeddings, # we need to define the seq_len of the inputs beforehand
num_channels=config.d_model,
num_channels=text_preprocessor.num_channels,
qk_channels=8 * 32,
v_channels=config.d_model,
v_channels=text_preprocessor.num_channels,
num_heads=8,
use_query_residual=False,
final_project=False,
......@@ -1502,22 +1503,24 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
concat_pos=True, max_resolution=config.train_size, num_bands=64, sine_only=False
)
image_preprocessor = PerceiverImagePreprocessor(
config,
prep_type="patches",
spatial_downsample=1,
conv_after_patching=True,
conv_after_patching_in_channels=54,
temporal_downsample=2,
position_encoding_type="fourier",
# position_encoding_kwargs
fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
)
self.perceiver = PerceiverModel(
config,
input_preprocessor=PerceiverImagePreprocessor(
config,
prep_type="patches",
spatial_downsample=1,
conv_after_patching=True,
conv_after_patching_in_channels=54,
temporal_downsample=2,
position_encoding_type="fourier",
# position_encoding_kwargs
fourier_position_encoding_kwargs=fourier_position_encoding_kwargs_preprocessor,
),
input_preprocessor=image_preprocessor,
decoder=PerceiverOpticalFlowDecoder(
config,
num_channels=config.d_model,
num_channels=image_preprocessor.num_channels,
output_image_shape=config.train_size,
rescale_factor=100.0,
# decoder kwargs
......@@ -2631,6 +2634,7 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
def __init__(self, config):
super().__init__()
self.config = config
self.embeddings = nn.Embedding(num_embeddings=config.vocab_size, embedding_dim=config.d_model)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
......
......@@ -147,19 +147,14 @@ class PerceiverModelTester:
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
elif model_class.__name__ == "PerceiverForImageClassificationLearned":
config.d_model = 512
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
elif model_class.__name__ == "PerceiverForImageClassificationFourier":
config.d_model = 261
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
elif model_class.__name__ == "PerceiverForImageClassificationConvProcessing":
config.d_model = 322
inputs = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
elif model_class.__name__ == "PerceiverForOpticalFlow":
config.d_model = 322
inputs = floats_tensor([self.batch_size, 2, 27, self.train_size[0], self.train_size[1]])
elif model_class.__name__ == "PerceiverForMultimodalAutoencoding":
config.d_model = 409
images = torch.randn(
(self.batch_size, self.num_frames, self.num_channels, self.image_size, self.image_size),
device=torch_device,
......@@ -211,8 +206,6 @@ class PerceiverModelTester:
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_sequence_classification(self, config, inputs, input_mask, sequence_labels, token_labels):
# set num_labels
config.num_labels = self.num_labels
model = PerceiverForSequenceClassification(config=config)
model.to(torch_device)
model.eval()
......@@ -222,9 +215,6 @@ class PerceiverModelTester:
def create_and_check_for_image_classification_learned(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 512
config.num_labels = self.num_labels
model = PerceiverForImageClassificationLearned(config=config)
model.to(torch_device)
model.eval()
......@@ -234,9 +224,6 @@ class PerceiverModelTester:
def create_and_check_for_image_classification_fourier(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 261
config.num_labels = self.num_labels
model = PerceiverForImageClassificationFourier(config=config)
model.to(torch_device)
model.eval()
......@@ -246,9 +233,6 @@ class PerceiverModelTester:
def create_and_check_for_image_classification_conv(
self, config, inputs, input_mask, sequence_labels, token_labels
):
# set d_model and num_labels
config.d_model = 322
config.num_labels = self.num_labels
model = PerceiverForImageClassificationConvProcessing(config=config)
model.to(torch_device)
model.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment