Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-cluster
Commits
37778e99
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "3a30bf56dcc4e8b2b8e4cb7f4615954acb883155"
Commit
37778e99
authored
Feb 06, 2018
by
rusty1s
Browse files
added new batch calculation
parent
f5812714
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
16 deletions
+20
-16
test/test_grid.py
test/test_grid.py
+12
-12
torch_cluster/functions/grid.py
torch_cluster/functions/grid.py
+6
-2
torch_cluster/functions/utils.py
torch_cluster/functions/utils.py
+2
-2
No files found.
test/test_grid.py
View file @
37778e99
...
@@ -26,13 +26,13 @@ def test_grid_cluster_cpu(tensor):
...
@@ -26,13 +26,13 @@ def test_grid_cluster_cpu(tensor):
output
=
grid_cluster
(
position
.
expand
(
2
,
5
,
2
),
size
)
output
=
grid_cluster
(
position
.
expand
(
2
,
5
,
2
),
size
)
assert
output
.
tolist
()
==
expected
.
expand
(
2
,
5
).
tolist
()
assert
output
.
tolist
()
==
expected
.
expand
(
2
,
5
).
tolist
()
expected
=
torch
.
LongTensor
([
0
,
1
,
3
,
2
,
4
])
position
=
position
.
repeat
(
2
,
1
)
batch
=
torch
.
LongTensor
([
0
,
0
,
1
,
1
,
1
])
batch
=
torch
.
LongTensor
([
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
1
])
output
=
grid_cluster
(
position
,
size
,
batch
)
expected
=
torch
.
LongTensor
([
0
,
3
,
1
,
0
,
2
,
4
,
7
,
5
,
4
,
6
])
expected_batch
=
torch
.
LongTensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
])
output
,
reduced_batch
=
grid_cluster
(
position
,
size
,
batch
)
assert
output
.
tolist
()
==
expected
.
tolist
()
assert
output
.
tolist
()
==
expected
.
tolist
()
assert
reduced_batch
.
tolist
()
==
expected_batch
.
tolist
()
output
=
grid_cluster
(
position
.
expand
(
2
,
5
,
2
),
size
,
batch
.
expand
(
2
,
5
))
assert
output
.
tolist
()
==
expected
.
expand
(
2
,
5
).
tolist
()
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'no CUDA'
)
...
@@ -59,10 +59,10 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
...
@@ -59,10 +59,10 @@ def test_grid_cluster_gpu(tensor): # pragma: no cover
output
=
grid_cluster
(
position
.
expand
(
2
,
5
,
2
),
size
)
output
=
grid_cluster
(
position
.
expand
(
2
,
5
,
2
),
size
)
assert
output
.
tolist
()
==
expected
.
expand
(
2
,
5
).
tolist
()
assert
output
.
tolist
()
==
expected
.
expand
(
2
,
5
).
tolist
()
expected
=
torch
.
LongTensor
([
0
,
1
,
3
,
2
,
4
])
position
=
position
.
repeat
(
2
,
1
)
batch
=
torch
.
cuda
.
LongTensor
([
0
,
0
,
1
,
1
,
1
])
batch
=
torch
.
cuda
.
LongTensor
([
0
,
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
,
1
])
output
=
grid_cluster
(
position
,
size
,
batch
)
expected
=
torch
.
LongTensor
([
0
,
3
,
1
,
0
,
2
,
4
,
7
,
5
,
4
,
6
])
expected_batch
=
torch
.
LongTensor
([
0
,
0
,
0
,
0
,
1
,
1
,
1
,
1
])
output
,
reduced_batch
=
grid_cluster
(
position
,
size
,
batch
)
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
assert
output
.
cpu
().
tolist
()
==
expected
.
tolist
()
assert
reduced_batch
.
cpu
().
tolist
()
==
expected_batch
.
tolist
()
output
=
grid_cluster
(
position
.
expand
(
2
,
5
,
2
),
size
,
batch
.
expand
(
2
,
5
))
assert
output
.
cpu
().
tolist
()
==
expected
.
expand
(
2
,
5
).
tolist
()
torch_cluster/functions/grid.py
View file @
37778e99
...
@@ -43,6 +43,10 @@ def grid_cluster(position, size, batch=None):
...
@@ -43,6 +43,10 @@ def grid_cluster(position, size, batch=None):
func
=
get_func
(
'grid'
,
position
)
func
=
get_func
(
'grid'
,
position
)
func
(
C
,
cluster
,
position
,
size
,
c_max
)
func
(
C
,
cluster
,
position
,
size
,
c_max
)
cluster
=
cluster
.
squeeze
(
dim
=-
1
)
cluster
=
cluster
.
squeeze
(
dim
=-
1
)
cluster
=
consecutive
(
cluster
)
cluster
,
u
=
consecutive
(
cluster
)
return
cluster
if
batch
is
None
:
return
cluster
else
:
batch
=
(
u
/
c_max
[
1
:].
prod
()).
long
()
return
cluster
,
batch
torch_cluster/functions/utils.py
View file @
37778e99
...
@@ -22,7 +22,7 @@ def get_type(max, cuda):
...
@@ -22,7 +22,7 @@ def get_type(max, cuda):
return
torch
.
cuda
.
LongTensor
if
cuda
else
torch
.
LongTensor
return
torch
.
cuda
.
LongTensor
if
cuda
else
torch
.
LongTensor
def
consecutive
(
tensor
):
def
consecutive
(
tensor
,
return_batch
=
None
):
size
=
tensor
.
size
()
size
=
tensor
.
size
()
u
=
unique
(
tensor
.
view
(
-
1
))
u
=
unique
(
tensor
.
view
(
-
1
))
len
=
u
[
-
1
]
+
1
len
=
u
[
-
1
]
+
1
...
@@ -31,4 +31,4 @@ def consecutive(tensor):
...
@@ -31,4 +31,4 @@ def consecutive(tensor):
arg
=
type
(
len
)
arg
=
type
(
len
)
arg
[
u
]
=
torch
.
arange
(
0
,
max
,
out
=
type
(
max
))
arg
[
u
]
=
torch
.
arange
(
0
,
max
,
out
=
type
(
max
))
tensor
=
arg
[
tensor
.
view
(
-
1
)]
tensor
=
arg
[
tensor
.
view
(
-
1
)]
return
tensor
.
view
(
size
).
long
()
return
tensor
.
view
(
size
).
long
()
,
u
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment