Unverified commit 84ad6af4 authored by Suraj Patil, committed by GitHub

minor fixes (#14026)

parent f5af8736
......@@ -15,6 +15,7 @@
""" PyTorch CLIP model. """
from dataclasses import dataclass
from typing import Any, Optional, Tuple
import torch
......@@ -71,6 +72,7 @@ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
return (caption_loss + image_loss) / 2.0
@dataclass
class CLIPOutput(ModelOutput):
"""
Args:
......@@ -297,10 +299,9 @@ class CLIPEncoderLayer(nn.Module):
):
"""
Args:
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(seq_len, batch, embed_dim)`
hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(batch, seq_len, embed_dim)`
attention_mask (:obj:`torch.FloatTensor`): attention mask of size
:obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
layer_head_mask (:obj:`torch.FloatTensor`): mask for attention heads in a given layer of size
:obj:`(config.encoder_attention_heads,)`.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
......@@ -497,7 +498,6 @@ class CLIPEncoder(nn.Module):
Args:
config: CLIPConfig
embed_tokens (nn.Embedding): output embedding
"""
def __init__(self, config: CLIPConfig):
......@@ -517,7 +517,7 @@ class CLIPEncoder(nn.Module):
):
r"""
Args:
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
into associated vectors than the model's internal embedding lookup matrix.
......
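The docstring fixes above document that CLIP's encoder operates on batch-first hidden states, i.e. `(batch, seq_len, embed_dim)`. Below is a minimal sketch that checks this contract through the public `CLIPTextModel` API; the `openai/clip-vit-base-patch32` checkpoint and the example sentences are assumptions for illustration, not part of this diff.

# Sketch: verify the batch-first shape documented in the corrected docstring.
# Assumes a transformers version with CLIP support and network access to the
# openai/clip-vit-base-patch32 checkpoint (an assumption, not part of this commit).
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

batch_size, seq_len = inputs["input_ids"].shape
# last_hidden_state is (batch, seq_len, hidden_size), matching the corrected docstring
assert outputs.last_hidden_state.shape == (batch_size, seq_len, model.config.hidden_size)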
......@@ -102,6 +102,7 @@ class CLIPVisionModelTester:
model = CLIPVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
    result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
......@@ -350,6 +351,7 @@ class CLIPTextModelTester:
model = CLIPTextModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
    result = model(input_ids, attention_mask=input_mask)
    result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
......@@ -429,6 +431,7 @@ class CLIPModelTester:
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = CLIPModel(config).to(torch_device).eval()
with torch.no_grad():
    result = model(input_ids, pixel_values, attention_mask)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
......
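The test changes above wrap each forward pass in `torch.no_grad()` so no gradient buffers are built during evaluation. A minimal sketch of the same pattern at inference time with `CLIPModel` follows; the checkpoint name and the random `pixel_values` standing in for preprocessed images are assumptions for illustration only.

# Sketch: evaluate CLIPModel without tracking gradients, mirroring the updated tests.
# Assumes the openai/clip-vit-base-patch32 checkpoint (an assumption, not part of this commit).
import torch
from transformers import CLIPModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model.eval()

text_inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
pixel_values = torch.randn(3, 3, 224, 224)  # (image_batch, channels, height, width), random stand-in

with torch.no_grad():  # no gradients are tracked during the forward pass
    outputs = model(
        input_ids=text_inputs["input_ids"],
        pixel_values=pixel_values,
        attention_mask=text_inputs["attention_mask"],
    )

# logits_per_image has shape (image_batch, text_batch), as asserted in CLIPModelTester
assert outputs.logits_per_image.shape == (3, 2)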