Unverified Commit 67a3920d authored by Clémentine Fourrier's avatar Clémentine Fourrier Committed by GitHub
Browse files

Fix Graphormer test suite (#21419)

* [FIX] path for Graphormer checkpoint

* [FIX] Test suite for graphormer

* [FIX] Update graphormer default num_classes
parent e006ab51
...@@ -39,8 +39,8 @@ class GraphormerConfig(PretrainedConfig): ...@@ -39,8 +39,8 @@ class GraphormerConfig(PretrainedConfig):
Args: Args:
num_classes (`int`, *optional*, defaults to 2): num_classes (`int`, *optional*, defaults to 1):
Number of target classes or labels, set to 1 if the task is a regression task. Number of target classes or labels, set to n for binary classification of n tasks.
num_atoms (`int`, *optional*, defaults to 512*9): num_atoms (`int`, *optional*, defaults to 512*9):
Number of node types in the graphs. Number of node types in the graphs.
num_edges (`int`, *optional*, defaults to 512*3): num_edges (`int`, *optional*, defaults to 512*3):
...@@ -134,7 +134,7 @@ class GraphormerConfig(PretrainedConfig): ...@@ -134,7 +134,7 @@ class GraphormerConfig(PretrainedConfig):
def __init__( def __init__(
self, self,
num_classes: int = 2, num_classes: int = 1,
num_atoms: int = 512 * 9, num_atoms: int = 512 * 9,
num_edges: int = 512 * 3, num_edges: int = 512 * 3,
num_in_degree: int = 512, num_in_degree: int = 512,
......
...@@ -40,7 +40,7 @@ class GraphormerModelTester: ...@@ -40,7 +40,7 @@ class GraphormerModelTester:
def __init__( def __init__(
self, self,
parent, parent,
num_classes=2, num_classes=1,
num_atoms=512 * 9, num_atoms=512 * 9,
num_edges=512 * 3, num_edges=512 * 3,
num_in_degree=512, num_in_degree=512,
...@@ -614,7 +614,7 @@ class GraphormerModelIntegrationTest(unittest.TestCase): ...@@ -614,7 +614,7 @@ class GraphormerModelIntegrationTest(unittest.TestCase):
[3, 3, 4, 3, 3, 3, 3, 4, 4, 3, 4, 2, 2, 0, 0, 0, 0], [3, 3, 4, 3, 3, 3, 3, 4, 4, 3, 4, 2, 2, 0, 0, 0, 0],
] ]
), ),
"x": tensor( "input_nodes": tensor(
[ [
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3]], [[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3]],
[[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [0], [0], [0], [0]], [[3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [3], [0], [0], [0], [0]],
...@@ -1279,15 +1279,11 @@ class GraphormerModelIntegrationTest(unittest.TestCase): ...@@ -1279,15 +1279,11 @@ class GraphormerModelIntegrationTest(unittest.TestCase):
output = model(**model_input)["logits"] output = model(**model_input)["logits"]
print(output.shape) expected_shape = torch.Size((2, 1))
print(output)
expected_shape = torch.Size(())
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
# TODO Replace values below with what was printed above. expected_logs = torch.tensor(
expected_slice = torch.tensor( [[7.6060], [7.4126]]
[[[-0.0483, 0.1188, -0.0313], [-0.0606, 0.1435, 0.0199], [-0.0235, 0.1519, 0.0175]]]
) )
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(output, expected_logs, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment