Unverified Commit d9b8d1a9 authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

update test (#16219)

parent 7e0d04be
...@@ -66,7 +66,9 @@ class MaskFormerModelTester: ...@@ -66,7 +66,9 @@ class MaskFormerModelTester:
self.mask_feature_size = mask_feature_size self.mask_feature_size = mask_feature_size
def prepare_config_and_inputs(self): def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]) pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
torch_device
)
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device) pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
...@@ -232,12 +234,12 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -232,12 +234,12 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_with_labels(self): def test_model_with_labels(self):
size = (self.model_tester.min_size,) * 2 size = (self.model_tester.min_size,) * 2
inputs = { inputs = {
"pixel_values": torch.randn((2, 3, *size)), "pixel_values": torch.randn((2, 3, *size), device=torch_device),
"mask_labels": torch.randn((2, 10, *size)), "mask_labels": torch.randn((2, 10, *size), device=torch_device),
"class_labels": torch.zeros(2, 10).long(), "class_labels": torch.zeros(2, 10, device=torch_device).long(),
} }
model = MaskFormerForInstanceSegmentation(MaskFormerConfig()) model = MaskFormerForInstanceSegmentation(MaskFormerConfig()).to(torch_device)
outputs = model(**inputs) outputs = model(**inputs)
self.assertTrue(outputs.loss is not None) self.assertTrue(outputs.loss is not None)
...@@ -249,7 +251,7 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -249,7 +251,7 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common() config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
model = model_class(config) model = model_class(config).to(torch_device)
outputs = model(**inputs, output_attentions=True) outputs = model(**inputs, output_attentions=True)
self.assertTrue(outputs.attentions is not None) self.assertTrue(outputs.attentions is not None)
...@@ -381,7 +383,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): ...@@ -381,7 +383,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
) )
expected_slice = torch.tensor( expected_slice = torch.tensor(
[[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]] [[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
) ).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE)) self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits # class_queries_logits
class_queries_logits = outputs.class_queries_logits class_queries_logits = outputs.class_queries_logits
...@@ -392,7 +394,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): ...@@ -392,7 +394,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
[3.6169e-02, -5.9025e00, -2.9313e00], [3.6169e-02, -5.9025e00, -2.9313e00],
[1.0766e-04, -7.7630e00, -5.1263e00], [1.0766e-04, -7.7630e00, -5.1263e00],
] ]
) ).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE)) self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_annotations_and_loss(self): def test_with_annotations_and_loss(self):
...@@ -406,7 +408,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase): ...@@ -406,7 +408,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)}, {"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
], ],
return_tensors="pt", return_tensors="pt",
) ).to(torch_device)
with torch.no_grad(): with torch.no_grad():
outputs = model(**inputs) outputs = model(**inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment