Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
ce9a53d1
Commit
ce9a53d1
authored
Feb 20, 2018
by
rusty1s
Browse files
return C with batch
parent
6a2e1a08
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
5 deletions
+5
-5
setup.py
setup.py
+1
-1
test/test_grid.py
test/test_grid.py
+2
-2
torch_cluster/__init__.py
torch_cluster/__init__.py
+1
-1
torch_cluster/functions/grid.py
torch_cluster/functions/grid.py
+1
-1
No files found.
setup.py
View file @
ce9a53d1
...
@@ -2,7 +2,7 @@ from os import path as osp
...
@@ -2,7 +2,7 @@ from os import path as osp
from
setuptools
import
setup
,
find_packages
from
setuptools
import
setup
,
find_packages
__version__
=
'0.2.
0
'
__version__
=
'0.2.
1
'
url
=
'https://github.com/rusty1s/pytorch_cluster'
url
=
'https://github.com/rusty1s/pytorch_cluster'
install_requires
=
[
'cffi'
,
'torch-unique'
]
install_requires
=
[
'cffi'
,
'torch-unique'
]
...
...
test/test_grid.py
View file @
ce9a53d1
...
@@ -50,7 +50,7 @@ def test_grid_cluster_cpu(tensor):
...
@@ -50,7 +50,7 @@ def test_grid_cluster_cpu(tensor):
output
,
C
=
grid_cluster
(
position
,
size
,
batch
,
fake_nodes
=
True
)
output
,
C
=
grid_cluster
(
position
,
size
,
batch
,
fake_nodes
=
True
)
expected
=
torch
.
LongTensor
([
0
,
5
,
1
,
0
,
2
,
6
,
11
,
7
,
6
,
8
])
expected
=
torch
.
LongTensor
([
0
,
5
,
1
,
0
,
2
,
6
,
11
,
7
,
6
,
8
])
assert
output
.
tolist
()
==
expected
.
tolist
()
assert
output
.
tolist
()
==
expected
.
tolist
()
assert
C
==
6
assert
C
==
12
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
...
@@ -101,4 +101,4 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
...
@@ -101,4 +101,4 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
output
,
C
=
grid_cluster
(
position
,
size
,
batch
,
fake_nodes
=
True
)
output
,
C
=
grid_cluster
(
position
,
size
,
batch
,
fake_nodes
=
True
)
expected
=
torch
.
LongTensor
([
0
,
5
,
1
,
0
,
2
,
6
,
11
,
7
,
6
,
8
])
expected
=
torch
.
LongTensor
([
0
,
5
,
1
,
0
,
2
,
6
,
11
,
7
,
6
,
8
])
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
assert
C
==
6
assert
C
==
12
torch_cluster/__init__.py
View file @
ce9a53d1
from
.functions.grid
import
grid_cluster
from
.functions.grid
import
grid_cluster
__version__
=
'0.2.
0
'
__version__
=
'0.2.
1
'
__all__
=
[
'grid_cluster'
,
'__version__'
]
__all__
=
[
'grid_cluster'
,
'__version__'
]
torch_cluster/functions/grid.py
View file @
ce9a53d1
...
@@ -50,7 +50,7 @@ def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
...
@@ -50,7 +50,7 @@ def grid_cluster(position, size, batch=None, origin=None, fake_nodes=False):
cluster
=
cluster
.
squeeze
(
dim
=-
1
)
cluster
=
cluster
.
squeeze
(
dim
=-
1
)
if
fake_nodes
:
if
fake_nodes
:
return
cluster
,
C
//
c_max
[
0
]
return
cluster
,
C
cluster
,
u
=
consecutive
(
cluster
)
cluster
,
u
=
consecutive
(
cluster
)
return
cluster
,
None
if
batch
is
None
else
(
u
/
(
C
//
c_max
[
0
])).
long
()
return
cluster
,
None
if
batch
is
None
else
(
u
/
(
C
//
c_max
[
0
])).
long
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment