Unverified Commit 2cdac665 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix device issues in `CLIPSegModelIntegrationTest` (#20467)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 61d3928b
...@@ -731,11 +731,11 @@ class CLIPSegModelIntegrationTest(unittest.TestCase): ...@@ -731,11 +731,11 @@ class CLIPSegModelIntegrationTest(unittest.TestCase):
) )
expected_masks_slice = torch.tensor( expected_masks_slice = torch.tensor(
[[-7.4577, -7.4952, -7.4072], [-7.3115, -7.0969, -7.1624], [-6.9472, -6.7641, -6.8911]] [[-7.4577, -7.4952, -7.4072], [-7.3115, -7.0969, -7.1624], [-6.9472, -6.7641, -6.8911]]
) ).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3)) self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3))
# verify conditional and pooled output # verify conditional and pooled output
expected_conditional = torch.tensor([0.5601, -0.0314, 0.1980]) expected_conditional = torch.tensor([0.5601, -0.0314, 0.1980]).to(torch_device)
expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328]) expected_pooled_output = torch.tensor([0.2692, -0.7197, -0.1328]).to(torch_device)
self.assertTrue(torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3)) self.assertTrue(torch.allclose(outputs.conditional_embeddings[0, :3], expected_conditional, atol=1e-3))
self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3)) self.assertTrue(torch.allclose(outputs.pooled_output[0, :3], expected_pooled_output, atol=1e-3))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment