"tests/models/bart/test_modeling_bart.py" did not exist on "a573777901e662ec2e565be312ffaeedef6effec"
Unverified commit 42d8dd87, authored by Yixiang Gao, committed by GitHub

Perceiver interpolate position embedding (#30979)



* add test that currently fails

* test passed

* all perceiver passed

* fixup, style, quality, repo-consistency, all passed

* Apply suggestions from code review: default to False + compute sqrt once only

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* fix a minor bracket

* replace dim with self._num_channels

* add arguments to the rest of the preprocessors

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 5855afd1
src/transformers/models/perceiver/modeling_perceiver.py

@@ -699,13 +699,24 @@ PERCEIVER_INPUTS_DOCSTRING = r"""
         output_hidden_states (`bool`, *optional*):
             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
             more detail.
+        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+            Whether to interpolate the pre-trained position encodings.
         return_dict (`bool`, *optional*):
             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
 """


 @add_start_docstrings(
-    """The Perceiver: a scalable, fully attentional architecture.""",
+    """The Perceiver: a scalable, fully attentional architecture.
+
+    <Tip>
+
+    Note that it's possible to fine-tune Perceiver on higher resolution images than the ones it has been trained on, by
+    setting `interpolate_pos_encoding` to `True` in the forward of the model. This will interpolate the pre-trained
+    position embeddings to the higher resolution.
+
+    </Tip>
+    """,
     PERCEIVER_MODEL_START_DOCSTRING,
 )
 class PerceiverModel(PerceiverPreTrainedModel):
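For readers outside the diff: a minimal usage sketch of the new flag, using the `deepmind/vision-perceiver-learned` checkpoint exercised by the test at the bottom of this commit. The random 384x384 tensor standing in for a processed image is an assumption for illustration; only the `interpolate_pos_encoding` keyword itself is what this commit adds.

```python
import torch
from transformers import PerceiverForImageClassificationLearned

# The checkpoint was trained on 224x224 inputs; we feed 384x384 and ask the
# model to interpolate its learned position embeddings to the larger grid.
model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
model.eval()

pixel_values = torch.randn(1, 3, 384, 384)  # stand-in for image-processor output
with torch.no_grad():
    logits = model(inputs=pixel_values, interpolate_pos_encoding=True).logits
print(logits.shape)  # torch.Size([1, num_labels])
```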
@@ -754,6 +765,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, PerceiverModelOutput]:
         r"""
@@ -857,7 +869,9 @@ class PerceiverModel(PerceiverPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if self.input_preprocessor is not None:
-            inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(inputs)
+            inputs, modality_sizes, inputs_without_pos = self.input_preprocessor(
+                inputs, interpolate_pos_encoding=interpolate_pos_encoding
+            )
         else:
             modality_sizes = None
             inputs_without_pos = None
@@ -1247,6 +1261,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         labels: Optional[torch.Tensor] = None,
+        interpolate_pos_encoding: bool = False,
         return_dict: Optional[bool] = None,
         pixel_values: Optional[torch.Tensor] = None,
     ) -> Union[Tuple, PerceiverClassifierOutput]:
@@ -1295,6 +1310,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
             head_mask=head_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
             return_dict=return_dict,
         )
         logits = outputs.logits if return_dict else outputs[0]
@@ -2749,9 +2765,31 @@ class PerceiverTrainablePositionEncoding(PerceiverAbstractPositionEncoding):
     def output_size(self, *args, **kwargs) -> int:
         return self._num_channels

-    def forward(self, batch_size: int) -> torch.Tensor:
+    def interpolate_pos_encoding(self, position_embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        num_positions = position_embeddings.shape[0]
+        new_height = new_width = math.sqrt(num_positions)
+        position_embeddings = position_embeddings.reshape(
+            1, int(new_height), int(new_width), self._num_channels
+        ).permute(0, 3, 1, 2)
+        position_embeddings = nn.functional.interpolate(
+            position_embeddings,
+            scale_factor=(height / new_height, width / new_width),
+            mode="bicubic",
+            align_corners=False,
+        )
+        position_embeddings = position_embeddings.reshape(1, self._num_channels, -1).permute(0, 2, 1).squeeze(0)
+        return position_embeddings
+
+    def forward(
+        self, batch_size: int, interpolate_pos_encoding: bool = False, input_size: torch.Size = None
+    ) -> torch.Tensor:
         position_embeddings = self.position_embeddings
+
+        if interpolate_pos_encoding:
+            height, width = input_size
+            height, width = height + 0.1, width + 0.1
+            position_embeddings = self.interpolate_pos_encoding(position_embeddings, height, width)
+
         if batch_size is not None:
             position_embeddings = position_embeddings.expand(batch_size, -1, -1)
         return position_embeddings
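A standalone shape walkthrough of the new `interpolate_pos_encoding` method, under assumed sizes (a 224x224-position table with 256 channels, resized to 384 per side; the real sizes come from the checkpoint): the flat learned table is reshaped into a square grid, bicubically resized, and flattened back. The `+ 0.1` nudge applied to `height`/`width` in `forward` is the usual ViT-style guard so that `floor(side * scale_factor)` lands on the intended grid size despite floating-point rounding.

```python
import math
import torch
import torch.nn as nn

num_positions, num_channels = 224 * 224, 256  # assumed checkpoint geometry
pos = torch.randn(num_positions, num_channels)  # flat learned table, shape (P, C)

height = width = 384 + 0.1  # target grid side; small offset guards rounding
side = math.sqrt(num_positions)  # original grid side (224.0), computed once

# (P, C) -> (1, 224, 224, C) -> (1, C, 224, 224)
grid = pos.reshape(1, int(side), int(side), num_channels).permute(0, 3, 1, 2)
grid = nn.functional.interpolate(
    grid, scale_factor=(height / side, width / side), mode="bicubic", align_corners=False
)  # (1, C, 384, 384): floor(224 * 384.1 / 224) = 384
flat = grid.reshape(1, num_channels, -1).permute(0, 2, 1).squeeze(0)  # back to (P', C)
print(flat.shape)  # torch.Size([147456, 256])
```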
@@ -2859,7 +2897,13 @@ class PerceiverTextPreprocessor(AbstractPreprocessor):
     def num_channels(self) -> int:
         return self.config.d_model

-    def forward(self, inputs: torch.LongTensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
+    def forward(
+        self,
+        inputs: torch.LongTensor,
+        pos: Optional[torch.Tensor] = None,
+        network_input_is_1d: bool = True,
+        interpolate_pos_encoding: bool = False,
+    ):
         embeddings_without_pos = self.embeddings(inputs)

         seq_length = inputs.shape[1]
@@ -3139,7 +3183,9 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):

         return inp_dim + pos_dim

-    def _build_network_inputs(self, inputs: torch.Tensor, network_input_is_1d: bool = True):
+    def _build_network_inputs(
+        self, inputs: torch.Tensor, network_input_is_1d: bool = True, interpolate_pos_encoding: bool = False
+    ):
         """
         Construct the final input, including position encoding.
@@ -3147,6 +3193,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
         """
         batch_size = inputs.shape[0]
+        input_size = inputs.shape[1:3]
         index_dims = inputs.shape[1:-1]
         indices = np.prod(index_dims)
@@ -3156,7 +3203,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
         # Construct the position encoding.
         if self.position_encoding_type == "trainable":
-            pos_enc = self.position_embeddings(batch_size)
+            pos_enc = self.position_embeddings(batch_size, interpolate_pos_encoding, input_size)
         elif self.position_encoding_type == "fourier":
             pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
@@ -3174,7 +3221,13 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
             inputs_with_pos = inputs + pos_enc
         return inputs_with_pos, inputs

-    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        pos: Optional[torch.Tensor] = None,
+        network_input_is_1d: bool = True,
+        interpolate_pos_encoding: bool = False,
+    ):
         if self.prep_type == "conv":
             # Convnet image featurization.
             # Downsamples spatially by a factor of 4
@@ -3218,7 +3271,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
         else:
             raise ValueError("Unsupported data format for conv1x1.")

-        inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d)
+        inputs, inputs_without_pos = self._build_network_inputs(inputs, network_input_is_1d, interpolate_pos_encoding)
         modality_sizes = None  # Size for each modality, only needed for multimodal

         return inputs, modality_sizes, inputs_without_pos
@@ -3338,7 +3391,13 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
         return inputs_with_pos, inputs

-    def forward(self, inputs: torch.Tensor, pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True):
+    def forward(
+        self,
+        inputs: torch.Tensor,
+        pos: Optional[torch.Tensor] = None,
+        network_input_is_1d: bool = True,
+        interpolate_pos_encoding: bool = False,
+    ):
         inputs = torch.reshape(inputs, [inputs.shape[0], -1, self.samples_per_patch])

         inputs, inputs_without_pos = self._build_network_inputs(inputs)
@@ -3391,7 +3450,11 @@ class PerceiverMultimodalPreprocessor(AbstractPreprocessor):
         return common_channel_size

     def forward(
-        self, inputs: Mapping[str, torch.Tensor], pos: Optional[torch.Tensor] = None, network_input_is_1d: bool = True
+        self,
+        inputs: Mapping[str, torch.Tensor],
+        pos: Optional[torch.Tensor] = None,
+        network_input_is_1d: bool = True,
+        interpolate_pos_encoding: bool = False,
     ) -> PreprocessorOutputType:
         padded = {}
         modality_sizes = {}
tests/models/perceiver/test_modeling_perceiver.py
@@ -1031,3 +1031,23 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
         )

         self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_interpolate_pos_encoding(self):
+        image_processor = PerceiverImageProcessor(size={"height": 384, "width": 384})
+        model = PerceiverForImageClassificationLearned.from_pretrained("deepmind/vision-perceiver-learned")
+        model.to(torch_device)
+
+        # prepare inputs
+        image = prepare_img()
+        inputs = image_processor(image, return_tensors="pt").pixel_values.to(torch_device)
+        input_mask = None
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(inputs=inputs, attention_mask=input_mask, interpolate_pos_encoding=True)
+
+        logits = outputs.logits
+
+        # verify logits
+        expected_shape = torch.Size((1, model.config.num_labels))
+        self.assertEqual(logits.shape, expected_shape)
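A note on what the test pins down: only the logits shape. At 384x384 the image preprocessor emits 384^2 = 147,456 positions, while the checkpoint's learned table holds 224^2 = 50,176; without `interpolate_pos_encoding=True`, the position encoding returned in `_build_network_inputs` could not be combined with the larger set of input positions, which is exactly the gap the bicubic resize in `PerceiverTrainablePositionEncoding` bridges.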