Unverified Commit 70f88eec authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Fix tapas issue (#12063)

* Fix scatter function to be compatible with torch-scatter 2.7.0

* Allow test again
parent e56e3140
......@@ -1697,9 +1697,9 @@ def _segment_reduce(values, index, segment_reduce_fn, name):
segment_means = scatter(
src=flat_values,
index=flat_index.indices.type(torch.long),
index=flat_index.indices.long(),
dim=0,
dim_size=flat_index.num_segments,
dim_size=int(flat_index.num_segments),
reduce=segment_reduce_fn,
)
......
......@@ -1044,7 +1044,6 @@ class TapasUtilitiesTest(unittest.TestCase):
# We use np.testing.assert_array_equal rather than Tensorflow's assertAllEqual
np.testing.assert_array_equal(maximum.numpy(), [2, 3])
@unittest.skip("Fix me I'm failing on CI")
def test_reduce_sum_vectorized(self):
values = torch.as_tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
index = IndexMap(indices=torch.as_tensor([0, 0, 1]), num_segments=2, batch_dims=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment