Unverified Commit 3b9174f2 authored by BHUVAN M, committed by GitHub

interpolation added for TVP. (#30863)

* Update TVP model to interpolate pre-trained image pad prompter encodings

* feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding

* added required comments

* Update TVP model to interpolate pre-trained image pad prompter encodings

* feat: Add 2D positional embeddings interpolation in TvpVisualInputEmbedding

* added required comments

* docstring and argument fix

* doc fixes and test case fix suggested in review.

* variable typo fix

* styling and name fixes for padding interpolation flag.
parent ea50b64b
@@ -193,34 +193,81 @@ class TvpVisualInputEmbedding(nn.Module):
        self.token_type_embeddings = nn.Embedding(1, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.max_grid_row_position_embeddings = config.max_grid_row_position_embeddings
        self.max_grid_col_position_embeddings = config.max_grid_col_position_embeddings

    def interpolate_pos_encoding(self, embedding: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Interpolates the pre-trained 2D positional embeddings so the model can be used on higher-resolution
        images (higher-resolution video frames) than it was trained on.
        """
        h0 = w0 = 1
        # if the height dimension has to be interpolated
        if height > self.max_grid_row_position_embeddings:
            h0 = height / self.max_grid_row_position_embeddings
        # if the width dimension has to be interpolated
        if width > self.max_grid_col_position_embeddings:
            w0 = width / self.max_grid_col_position_embeddings
        embedding = embedding.permute(0, 3, 1, 2)  # (batch_size, hidden_dim, height, width)
        embedding = nn.functional.interpolate(
            embedding,
            scale_factor=(h0, w0),
            mode="bicubic",
            align_corners=False,
        )
        embedding = embedding.permute(0, 2, 3, 1)  # (batch_size, height, width, hidden_dim)
        return embedding
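For intuition, here is a minimal standalone sketch of the same upscaling step. The grid sizes (an 8x8 pre-trained grid stretched to 12x16) are made up for illustration; only the permute / interpolate / permute pattern mirrors interpolate_pos_encoding above:

import torch
import torch.nn as nn

# Pre-trained 2D positional grid, channels-last: (batch, rows, cols, hidden) -- toy sizes
pos = torch.randn(1, 8, 8, 64)

# Target grid is larger than the stored one, so compute per-axis scale factors
height, width = 12, 16
h0, w0 = height / 8, width / 8

pos = pos.permute(0, 3, 1, 2)  # interpolate() expects channels-first: (batch, hidden, rows, cols)
pos = nn.functional.interpolate(pos, scale_factor=(h0, w0), mode="bicubic", align_corners=False)
pos = pos.permute(0, 2, 3, 1)  # back to channels-last

print(pos.shape)  # torch.Size([1, 12, 16, 64])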
    def add_2d_positional_embeddings(self, grid, interpolate_pos_encoding: bool = False):
        """
        Args:
            grid: (batch_size, height, width, hidden_dim)
            interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.

        Returns:
            grid: (batch_size, height, width, hidden_dim) with the 2D positional embeddings added
        """
        batch_size, height, width, hidden_dim = grid.shape

        # add row-wise position embeddings
        # (height, )
        row_height = min(self.max_grid_row_position_embeddings, height)
        row_position_ids = torch.arange(row_height, dtype=torch.long, device=grid.device)
        # (height, hidden_dim)
        row_position_embeddings = self.row_position_embeddings(row_position_ids)
        row_shape = (1,) * (len(grid.shape) - 3) + (row_height, 1, hidden_dim)
        # (batch_size, height, 1, hidden_dim)
        row_position_embeddings = row_position_embeddings.view(*row_shape)

        # add column-wise position embeddings
        row_width = min(self.max_grid_col_position_embeddings, width)
        col_position_ids = torch.arange(row_width, dtype=torch.long, device=grid.device)
        # (width, hidden_dim)
        col_position_embeddings = self.col_position_embeddings(col_position_ids)
        col_shape = (batch_size, 1, row_width, hidden_dim)
        # (batch_size, 1, width, hidden_dim)
        col_position_embeddings = col_position_embeddings.view(*col_shape)

        # (batch_size, height, width, hidden_dim)
        positional_embeddings = row_position_embeddings + col_position_embeddings

        # Interpolation is triggered ONLY when the input image is larger, in either dimension, than the original position embedding grid
        if interpolate_pos_encoding and (
            height > self.max_grid_row_position_embeddings or width > self.max_grid_col_position_embeddings
        ):
            grid = grid + self.interpolate_pos_encoding(positional_embeddings, height, width)
        else:
            grid = grid + positional_embeddings
        return grid
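The positional grid being interpolated is just a broadcast sum of the row table and the column table; a toy sketch (sizes invented, and the real code also carries a leading batch dimension) of the shape arithmetic used by add_2d_positional_embeddings:

import torch

height, width, hidden = 3, 5, 4  # toy sizes, not the model's
row = torch.randn(height, hidden).view(height, 1, hidden)  # like row_position_embeddings.view(*row_shape)
col = torch.randn(width, hidden).view(1, width, hidden)    # like col_position_embeddings.view(*col_shape)

# Broadcasting pairs every row embedding with every column embedding
pos_2d = row + col
print(pos_2d.shape)  # torch.Size([3, 5, 4])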
    def forward(self, grid, interpolate_pos_encoding: bool = False):
        """
        Args:
            grid: Array of shape (batch_size, num_frames, height, width, num_channels).
                It contains processed frames extracted from videos, and is generated by Tvp image preprocessor. Note,
                num_frames can be 1
            interpolate_pos_encoding: (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the pre-trained position encodings.

        Returns:
            embeddings: The embedding of grid with size (batch_size, height*width, num_channels)
@@ -229,7 +276,7 @@ class TvpVisualInputEmbedding(nn.Module):
        batch_size, num_frames, height, width, num_channels = grid.shape
        # temporal mean pooling, (batch_size, height, width, hidden_size)
        grid = grid.mean(1)
        grid = self.add_2d_positional_embeddings(grid, interpolate_pos_encoding=interpolate_pos_encoding)
        # image token sequence, (batch_size, height*width, num_channels)
        visual_tokens = grid.view(batch_size, -1, num_channels)
        visual_tokens_shape = visual_tokens.shape[:-1]
@@ -586,6 +633,9 @@ TVP_INPUTS_DOCSTRING = r"""
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
            Whether to interpolate the pre-trained image pad prompter encodings and positional encodings.
"""
@@ -639,7 +689,6 @@ class TvpFramePadPrompter(nn.Module):
        self.num_frames = config.num_frames
        self.max_img_size = config.max_img_size
        self.visual_prompter_apply = config.visual_prompter_apply
        self.base_size = config.max_img_size - config.visual_prompt_size * 2
        self.pad_up = nn.Parameter(
            torch.randn([1, config.num_frames, 3, config.visual_prompt_size, config.max_img_size])
@@ -670,19 +719,49 @@ class TvpFramePadPrompter(nn.Module):
            )
        )

    def interpolate_pad_encoding(self, prompt: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """
        Interpolates the pre-trained pad prompt weights so the model can be used on higher-resolution images
        (higher-resolution video frames) than it was trained on.
        """
        # scale factors of the input height and width w.r.t. config.max_img_size
        h0, w0 = height / self.max_img_size, width / self.max_img_size

        batch, num_frames, channels, prompt_height, prompt_width = prompt.shape

        # fold batch and num_frames into a single dimension, i.e. (b, frames, c, h, w) -> (b*frames, c, h, w), so bicubic interpolation can be applied
        prompt = prompt.reshape(batch * num_frames, channels, prompt_height, prompt_width)
        prompt = nn.functional.interpolate(
            prompt,
            scale_factor=(h0, w0),
            mode="bicubic",
            align_corners=False,
        )
        # reshape back to (batch, frames, channels, height, width), where height and width are the new interpolated sizes
        prompt = prompt.reshape(batch, num_frames, channels, height, width)
        return prompt

    def forward(self, pixel_values, interpolate_pad_encoding: bool = False):
        height, width = (
            (pixel_values.shape[-2], pixel_values.shape[-1])
            if interpolate_pad_encoding
            else (self.max_img_size, self.max_img_size)
        )
        if self.visual_prompter_apply not in ("add", "remove", "replace"):
            raise ValueError(f"Invalid visual_prompter_apply value {self.visual_prompter_apply}")
        if self.visual_prompter_apply in ("replace", "remove"):
            visual_prompt_mask = torch.ones([height, width], dtype=pixel_values.dtype, device=pixel_values.device)
            pixel_values *= visual_prompt_mask
        if self.visual_prompter_apply in ("replace", "add"):
            base = torch.zeros(1, self.num_frames, 3, self.base_size, self.base_size, device=pixel_values.device)
            prompt = torch.cat([self.pad_left, base, self.pad_right], dim=4)
            prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=3)
            prompt = torch.cat(pixel_values.size(0) * [prompt])
            if interpolate_pad_encoding:
                prompt = self.interpolate_pad_encoding(prompt, height, width)
            pixel_values = pixel_values + prompt.to(pixel_values.dtype)
        return pixel_values
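As a rough worked example of the prompter path (the 448x448 max_img_size matches the released TVP config, but treat it as an assumption here): a 480x640 frame gives scale factors h0 = 480/448 ≈ 1.07 and w0 = 640/448 ≈ 1.43, and the reshape / interpolate / reshape round trip looks like this sketch:

import torch
import torch.nn as nn

batch, num_frames, channels = 2, 1, 3
max_img_size, height, width = 448, 480, 640  # assumed config value and input frame size

prompt = torch.randn(batch, num_frames, channels, max_img_size, max_img_size)

# Fold frames into the batch dimension so interpolate() receives a 4D tensor
prompt = prompt.reshape(batch * num_frames, channels, max_img_size, max_img_size)
prompt = nn.functional.interpolate(
    prompt, scale_factor=(height / max_img_size, width / max_img_size), mode="bicubic", align_corners=False
)
prompt = prompt.reshape(batch, num_frames, channels, height, width)
print(prompt.shape)  # torch.Size([2, 1, 3, 480, 640])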
@@ -738,6 +817,7 @@ class TvpModel(TvpPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        r"""
        Returns:
@@ -756,13 +836,17 @@ class TvpModel(TvpPreTrainedModel):
        >>> output = model(text_inputs.input_ids, pixel_values, text_inputs.attention_mask)
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.return_dict
        # Add visual prompt, it compensates for the spatiotemporal information loss in 2D visual features.
        pixel_values = self.vision_model(
            self.visual_prompter(pixel_values, interpolate_pad_encoding=interpolate_pos_encoding)
        )
        # (batch_size, sequence_length, hidden_size)
        text_embedding_output = self.embeddings(input_ids=input_ids)
        # (batch_size, visual_sequence_length, hidden_size)
        visual_embedding_output = self.visual_embeddings(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )
        if attention_mask is not None:
            # (batch_size, visual_sequence_length)
            visual_attention_mask = attention_mask.new_ones(visual_embedding_output.shape[:2])
@@ -791,7 +875,6 @@ class TvpModel(TvpPreTrainedModel):
        pooled_output = self.dropout(pooled_output)
        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]
        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
@@ -841,6 +924,7 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ):
        r"""
        labels (`torch.FloatTensor` of shape `(batch_size, 3)`, *optional*):
@@ -869,9 +953,9 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        pooler_output = outputs[1]
        logits = self.video_grounding_head(pooler_output)
        loss = None
@@ -884,7 +968,6 @@ class TvpForVideoGrounding(TvpPreTrainedModel):
            + self.config.distance_loss_weight * loss_dict["distance"]
            + self.config.duration_loss_weight * loss_dict["duration"]
        )
        if not return_dict:
            outputs = (logits,) + outputs[2:]
            if loss is not None:
...
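Downstream, the only user-visible change is the new flag; a hedged end-to-end sketch (the Intel/tvp-base checkpoint id, the toy token ids, and the 480x640 input size are assumptions, mirroring the docstring example above):

import torch
from transformers import TvpForVideoGrounding

model = TvpForVideoGrounding.from_pretrained("Intel/tvp-base")

# (batch, num_frames, channels, height, width), larger than the 448x448 pre-training resolution
pixel_values = torch.rand(1, model.config.num_frames, 3, 480, 640)
input_ids = torch.tensor([[101, 2054, 102]])  # toy BERT-style token ids
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # interpolate_pos_encoding=True upscales both the pad prompt and the 2D position grid
    outputs = model(input_ids, pixel_values, attention_mask, interpolate_pos_encoding=True)
print(outputs.logits.shape)  # (batch, 2): predicted start/end of the grounded moment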
@@ -256,7 +256,7 @@ def prepare_img():
class TvpModelIntegrationTests(unittest.TestCase):
    @cached_property
    def default_image_processor(self):
        return TvpImageProcessor.from_pretrained("Jiqing/tiny-random-tvp")

    def test_inference_no_head(self):
        model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
@@ -297,3 +297,41 @@ class TvpModelIntegrationTests(unittest.TestCase):
        assert outputs.logits.shape == expected_shape
        expected_slice = torch.tensor([[0.5061, 0.4988]]).to(torch_device)
        self.assertTrue(torch.allclose(outputs.logits, expected_slice, atol=1e-4))

    def test_interpolate_inference_no_head(self):
        model = TvpModel.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
        image_processor = self.default_image_processor
        image = prepare_img()  # 480x640
        encoding = image_processor(
            images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
        )
        input_ids = torch.tensor([[1, 2]])
        attention_mask = torch.tensor([[1, 1]])
        encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
        encoding.to(torch_device)
        with torch.no_grad():
            outputs = model(**encoding, interpolate_pos_encoding=True)
        expected_shape = torch.Size((1, 1212, 128))
        assert outputs.last_hidden_state.shape == expected_shape

    def test_interpolate_inference_with_head(self):
        model = TvpForVideoGrounding.from_pretrained("Jiqing/tiny-random-tvp").to(torch_device)
        image_processor = self.default_image_processor
        image = prepare_img()  # 480x640
        encoding = image_processor(
            images=image, return_tensors="pt", do_resize=False, do_pad=False, do_center_crop=False
        )
        input_ids = torch.tensor([[1, 2]])
        attention_mask = torch.tensor([[1, 1]])
        encoding.update({"input_ids": input_ids, "attention_mask": attention_mask})
        encoding.to(torch_device)
        with torch.no_grad():
            outputs = model(**encoding, interpolate_pos_encoding=True, output_hidden_states=True)
        expected_shape = torch.Size((1, 1212, 128))
        assert outputs.hidden_states[-1].shape == expected_shape