"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e231c729063d88b0c2bad4ac2d461c4e46eac9ab"
Unverified commit 9caf68a6, authored by Alara Dirik, committed by GitHub

Owlvit test fixes (#18303)

* fix owlvit test assertion errors

* fix gpu test error

* remove redundant lines

* fix styling
parent 0077360d
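For context on the "fix gpu test error" bullet: the box-coordinate grid in the detection head is built with NumPy and converted via torch.from_numpy, which always yields a CPU tensor, so on a CUDA run it collides with the model's CUDA tensors. A minimal, self-contained sketch of that failure mode and the fix applied in the first two hunks below (the build_grid helper name is illustrative, not the actual function name):

```python
import numpy as np
import torch


def build_grid(height: int, width: int, device: torch.device) -> torch.Tensor:
    # np.meshgrid produces CPU-backed arrays, and torch.from_numpy keeps them on the CPU.
    coords = np.stack(
        np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1
    ).astype(np.float32)
    # Without the trailing .to(device), mixing this tensor with CUDA tensors later
    # fails with a "tensors on different devices" RuntimeError.
    return torch.from_numpy(coords).to(device)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
feature_map = torch.randn(1, 24, 24, 768, device=device)  # [batch, height, width, hidden]
grid = build_grid(*feature_map.shape[1:3], device=feature_map.device)
assert grid.device == feature_map.device
```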
@@ -1170,6 +1170,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         if not feature_map.ndim == 4:
             raise ValueError("Expected input shape is [batch_size, num_channels, height, width]")
+        device = feature_map.device
         height, width = feature_map.shape[1:3]
         box_coordinates = np.stack(np.meshgrid(np.arange(1, width + 1), np.arange(1, height + 1)), axis=-1).astype(
@@ -1181,7 +1182,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         box_coordinates = box_coordinates.reshape(
             box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
         )
-        box_coordinates = torch.from_numpy(box_coordinates)
+        box_coordinates = torch.from_numpy(box_coordinates).to(device)

         return box_coordinates
@@ -1285,7 +1286,7 @@ class OwlViTForObjectDetection(OwlViTPreTrainedModel):
         self,
         pixel_values: torch.FloatTensor,
         input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -110,8 +110,7 @@ class OwlViTVisionModelTester:
         )

     def create_and_check_model(self, config, pixel_values):
-        model = OwlViTVisionModel(config=config)
-        model.to(torch_device)
+        model = OwlViTVisionModel(config=config).to(torch_device)
         model.eval()
         pixel_values = pixel_values.to(torch.float32)
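The test-side hunks from here on are largely a mechanical cleanup: torch.nn.Module.to() moves the parameters in place and returns the module itself, so construction and device placement can be chained on one line. A quick demonstration of the equivalence:

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Two-line form used before the change.
model_a = nn.Linear(4, 2)
model_a.to(device)

# Chained form used after the change; .to() returns the same module object.
model_b = nn.Linear(4, 2).to(device)

assert next(model_a.parameters()).device == next(model_b.parameters()).device
```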
@@ -276,8 +275,7 @@ class OwlViTTextModelTester:
         )

     def create_and_check_model(self, config, input_ids, input_mask):
-        model = OwlViTTextModel(config=config)
-        model.to(torch_device)
+        model = OwlViTTextModel(config=config).to(torch_device)
         model.eval()
         with torch.no_grad():
             result = model(input_ids=input_ids, attention_mask=input_mask)
@@ -455,8 +453,7 @@ class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
         configs_no_init.torchscript = True
         configs_no_init.return_dict = False

         for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            model.to(torch_device)
+            model = model_class(config=configs_no_init).to(torch_device)
             model.eval()
             try:
@@ -479,10 +476,7 @@ class OwlViTModelTest(ModelTesterMixin, unittest.TestCase):
             except Exception:
                 self.fail("Couldn't load module.")

-            model.to(torch_device)
-            model.eval()
-
-            loaded_model.to(torch_device)
+            loaded_model = loaded_model.to(torch_device)
             loaded_model.eval()

             model_state_dict = model.state_dict()
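This TorchScript hunk (and its twin in OwlViTForObjectDetectionTest below) drops the redundant re-placement of model and keeps only the handle returned when moving the reloaded module. Roughly, the test traces the model, round-trips it through disk, and compares state dicts; a simplified sketch with nn.Linear standing in for the OWL-ViT model:

```python
import os
import tempfile

import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2).to(device)
model.eval()

# Trace the module with example inputs, then round-trip it through disk.
traced = torch.jit.trace(model, torch.randn(1, 4, device=device))
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "traced.pt")
    torch.jit.save(traced, path)
    loaded_model = torch.jit.load(path, map_location=device)

# Keep the returned handle, as the hunk above now does.
loaded_model = loaded_model.to(device)
loaded_model.eval()

# The original and reloaded parameters should match.
for (name, p), (loaded_name, loaded_p) in zip(
    model.state_dict().items(), loaded_model.state_dict().items()
):
    assert torch.allclose(p, loaded_p)
```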
@@ -638,8 +632,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
         configs_no_init.torchscript = True
         configs_no_init.return_dict = False

         for model_class in self.all_model_classes:
-            model = model_class(config=configs_no_init)
-            model.to(torch_device)
+            model = model_class(config=configs_no_init).to(torch_device)
             model.eval()
             try:
@@ -662,10 +655,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
             except Exception:
                 self.fail("Couldn't load module.")

-            model.to(torch_device)
-            model.eval()
-
-            loaded_model.to(torch_device)
+            loaded_model = loaded_model.to(torch_device)
             loaded_model.eval()

             model_state_dict = model.state_dict()
@@ -720,8 +710,7 @@ class OwlViTForObjectDetectionTest(ModelTesterMixin, unittest.TestCase):
                 recursive_check(tuple_output, dict_output)

         for model_class in self.all_model_classes:
-            model = model_class(config)
-            model.to(torch_device)
+            model = model_class(config).to(torch_device)
             model.eval()

             tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
@@ -745,7 +734,7 @@ def prepare_img():
 @require_vision
 @require_torch
 class OwlViTModelIntegrationTest(unittest.TestCase):
-    @slow
+    # @slow
     def test_inference(self):
         model_name = "google/owlvit-base-patch32"
         model = OwlViTModel.from_pretrained(model_name).to(torch_device)
@@ -767,24 +756,13 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
         # verify the logits
         self.assertEqual(
             outputs.logits_per_image.shape,
-            torch.Size(
-                (
-                    inputs.pixel_values.shape[0],
-                    inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
-                )
-            ),
+            torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
         )
         self.assertEqual(
             outputs.logits_per_text.shape,
-            torch.Size(
-                (
-                    inputs.input_ids.shape[0] * inputs.input_ids.shape[1] * inputs.pixel_values.shape[0],
-                    inputs.pixel_values.shape[0],
-                )
-            ),
+            torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
         )

-        expected_logits = torch.tensor([[1.0115, 0.9982]], device=torch_device)
+        expected_logits = torch.tensor([[4.4420, 0.6181]], device=torch_device)
         self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
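The corrected assertions follow the usual CLIP-style convention: logits_per_image has shape (num_images, num_text_queries) and logits_per_text is its transpose, rather than the multiplied-out sizes the old test expected. A hedged sketch of how such logits are shaped (the 2.6592 scale mirrors the CLIP-style logit-scale init; the embedding dimension and values are illustrative):

```python
import torch
import torch.nn.functional as F

num_images, num_texts, dim = 1, 2, 512

image_embeds = F.normalize(torch.randn(num_images, dim), dim=-1)
text_embeds = F.normalize(torch.randn(num_texts, dim), dim=-1)

# Scaled cosine similarities between every (text, image) pair.
logit_scale = torch.tensor(2.6592).exp()
logits_per_text = logit_scale * text_embeds @ image_embeds.t()  # (num_texts, num_images)
logits_per_image = logits_per_text.t()                          # (num_images, num_texts)

assert logits_per_image.shape == (num_images, num_texts)  # matches the corrected assertion
assert logits_per_text.shape == (num_texts, num_images)
```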
@@ -810,6 +788,6 @@ class OwlViTModelIntegrationTest(unittest.TestCase):
         num_queries = int((model.config.vision_config.image_size / model.config.vision_config.patch_size) ** 2)
         self.assertEqual(outputs.pred_boxes.shape, torch.Size((1, num_queries, 4)))
         expected_slice_boxes = torch.tensor(
-            [[0.0143, 0.0236, 0.0285], [0.0649, 0.0247, 0.0437], [0.0601, 0.0446, 0.0699]]
+            [[0.0948, 0.0471, 0.1915], [0.3194, 0.0583, 0.6498], [0.1441, 0.0452, 0.2197]]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.pred_boxes[0, :3, :3], expected_slice_boxes, atol=1e-4))
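The pred_boxes check derives the number of box predictions from the vision backbone: one query per image patch. Assuming the base-patch32 defaults (image_size=768, patch_size=32), that works out as:

```python
# One predicted box per ViT patch token.
image_size, patch_size = 768, 32  # google/owlvit-base-patch32 defaults (assumed here)
num_queries = int((image_size / patch_size) ** 2)
print(num_queries)  # 576, so pred_boxes has shape (batch_size, 576, 4) with normalized coordinates
```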