"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c02421883b2a59e075ec87de8d82f02a944fb5e8"
Unverified Commit d9b8d1a9 authored by Francesco Saverio Zuppichini's avatar Francesco Saverio Zuppichini Committed by GitHub
Browse files

update test (#16219)

parent 7e0d04be
......@@ -66,7 +66,9 @@ class MaskFormerModelTester:
self.mask_feature_size = mask_feature_size
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size])
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size]).to(
torch_device
)
pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)
......@@ -232,12 +234,12 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
def test_model_with_labels(self):
size = (self.model_tester.min_size,) * 2
inputs = {
"pixel_values": torch.randn((2, 3, *size)),
"mask_labels": torch.randn((2, 10, *size)),
"class_labels": torch.zeros(2, 10).long(),
"pixel_values": torch.randn((2, 3, *size), device=torch_device),
"mask_labels": torch.randn((2, 10, *size), device=torch_device),
"class_labels": torch.zeros(2, 10, device=torch_device).long(),
}
model = MaskFormerForInstanceSegmentation(MaskFormerConfig())
model = MaskFormerForInstanceSegmentation(MaskFormerConfig()).to(torch_device)
outputs = model(**inputs)
self.assertTrue(outputs.loss is not None)
......@@ -249,7 +251,7 @@ class MaskFormerModelTest(ModelTesterMixin, unittest.TestCase):
config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model = model_class(config).to(torch_device)
outputs = model(**inputs, output_attentions=True)
self.assertTrue(outputs.attentions is not None)
......@@ -381,7 +383,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
)
expected_slice = torch.tensor(
[[-1.3738, -1.7725, -1.9365], [-1.5978, -1.9869, -2.1524], [-1.5796, -1.9271, -2.0940]]
)
).to(torch_device)
self.assertTrue(torch.allclose(masks_queries_logits[0, 0, :3, :3], expected_slice, atol=TOLERANCE))
# class_queries_logits
class_queries_logits = outputs.class_queries_logits
......@@ -392,7 +394,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
[3.6169e-02, -5.9025e00, -2.9313e00],
[1.0766e-04, -7.7630e00, -5.1263e00],
]
)
).to(torch_device)
self.assertTrue(torch.allclose(outputs.class_queries_logits[0, :3, :3], expected_slice, atol=TOLERANCE))
def test_with_annotations_and_loss(self):
......@@ -406,7 +408,7 @@ class MaskFormerModelIntegrationTest(unittest.TestCase):
{"masks": np.random.rand(10, 384, 384).astype(np.float32), "labels": np.zeros(10).astype(np.int64)},
],
return_tensors="pt",
)
).to(torch_device)
with torch.no_grad():
outputs = model(**inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment