Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
62815576
Commit
62815576
authored
Jan 29, 2020
by
rusty1s
Browse files
moved extensions to torch.ops
parent
0a221ab8
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
55 additions
and
41 deletions
+55
-41
torch_scatter/segment.py
torch_scatter/segment.py
+55
-28
torch_scatter/utils/ext.py
torch_scatter/utils/ext.py
+0
-13
No files found.
torch_scatter/segment.py
View file @
62815576
import
torch
import
torch
from
torch_scatter
import
segment_cpu
,
gather_cpu
from
torch_scatter.helpers
import
min_value
,
max_value
from
torch_scatter.helpers
import
min_value
,
max_value
if
torch
.
cuda
.
is_available
():
from
torch_scatter
import
segment_cuda
,
gather_cuda
def
seg
(
is_cuda
):
return
segment_cuda
if
is_cuda
else
segment_cpu
def
gat
(
is_cuda
):
return
gather_cuda
if
is_cuda
else
gather_cpu
class
SegmentCOO
(
torch
.
autograd
.
Function
):
class
SegmentCOO
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
...
@@ -37,7 +24,12 @@ class SegmentCOO(torch.autograd.Function):
...
@@ -37,7 +24,12 @@ class SegmentCOO(torch.autograd.Function):
out
=
src
.
new_full
(
size
,
fill_value
)
out
=
src
.
new_full
(
size
,
fill_value
)
out
,
arg_out
=
seg
(
src
.
is_cuda
).
segment_coo
(
src
,
index
,
out
,
reduce
)
if
src
.
is_cuda
:
out
,
arg_out
=
torch
.
ops
.
torch_scatter_cuda
.
segment_coo
(
src
,
index
,
out
,
reduce
)
else
:
out
,
arg_out
=
torch
.
ops
.
torch_scatter_cpu
.
segment_coo
(
src
,
index
,
out
,
reduce
)
if
fill_value
!=
0
:
if
fill_value
!=
0
:
out
.
masked_fill_
(
out
==
fill_value
,
0
)
out
.
masked_fill_
(
out
==
fill_value
,
0
)
...
@@ -56,25 +48,39 @@ class SegmentCOO(torch.autograd.Function):
...
@@ -56,25 +48,39 @@ class SegmentCOO(torch.autograd.Function):
grad_src
=
None
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
reduce
==
'sum'
or
ctx
.
reduce
==
'add'
:
if
ctx
.
reduce
==
'sum'
or
ctx
.
reduce
==
'add'
:
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
if
grad_out
.
is_cuda
:
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
grad_src
=
torch
.
ops
.
torch_scatter_cuda
.
gather_coo
(
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
else
:
grad_src
=
torch
.
ops
.
torch_scatter_cpu
.
gather_coo
(
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
elif
ctx
.
reduce
==
'mean'
:
elif
ctx
.
reduce
==
'mean'
:
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
if
grad_out
.
is_cuda
:
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
grad_src
=
torch
.
ops
.
torch_scatter_cuda
.
gather_coo
(
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
else
:
grad_src
=
torch
.
ops
.
torch_scatter_cpu
.
gather_coo
(
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
count
=
arg_out
# Gets pre-computed on GPU but not on CPU.
count
=
arg_out
# Gets pre-computed on GPU but not on CPU.
if
count
is
None
:
if
count
is
None
:
size
=
list
(
index
.
size
())
size
=
list
(
index
.
size
())
size
[
-
1
]
=
grad_out
.
size
(
index
.
dim
()
-
1
)
size
[
-
1
]
=
grad_out
.
size
(
index
.
dim
()
-
1
)
count
=
segment
_cpu
.
segment_coo
(
count
=
torch
.
ops
.
torch_scatter
_cpu
.
segment_coo
(
torch
.
ones_like
(
index
,
dtype
=
grad_out
.
dtype
),
index
,
torch
.
ones_like
(
index
,
dtype
=
grad_out
.
dtype
),
index
,
grad_out
.
new_zeros
(
size
),
'sum'
)[
0
].
clamp_
(
min
=
1
)
grad_out
.
new_zeros
(
size
),
'sum'
)[
0
].
clamp_
(
min
=
1
)
count
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
if
grad_out
.
is_cuda
:
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
count
=
torch
.
ops
.
torch_scatter_cuda
.
gather_coo
(
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
else
:
count
=
torch
.
ops
.
torch_scatter_cpu
.
gather_coo
(
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
for
_
in
range
(
grad_out
.
dim
()
-
index
.
dim
()):
for
_
in
range
(
grad_out
.
dim
()
-
index
.
dim
()):
count
=
count
.
unsqueeze
(
-
1
)
count
=
count
.
unsqueeze
(
-
1
)
grad_src
.
div_
(
count
)
grad_src
.
div_
(
count
)
elif
ctx
.
reduce
==
'min'
or
ctx
.
reduce
==
'max'
:
elif
ctx
.
reduce
==
'min'
or
ctx
.
reduce
==
'max'
:
src_size
[
index
.
dim
()
-
1
]
+=
1
src_size
[
index
.
dim
()
-
1
]
+=
1
grad_src
=
grad_out
.
new_zeros
(
src_size
).
scatter_
(
grad_src
=
grad_out
.
new_zeros
(
src_size
).
scatter_
(
...
@@ -95,7 +101,13 @@ class SegmentCSR(torch.autograd.Function):
...
@@ -95,7 +101,13 @@ class SegmentCSR(torch.autograd.Function):
ctx
.
reduce
=
reduce
ctx
.
reduce
=
reduce
ctx
.
src_size
=
list
(
src
.
size
())
ctx
.
src_size
=
list
(
src
.
size
())
out
,
arg_out
=
seg
(
src
.
is_cuda
).
segment_csr
(
src
,
indptr
,
out
,
reduce
)
if
src
.
is_cuda
:
out
,
arg_out
=
torch
.
ops
.
torch_scatter_cuda
.
segment_csr
(
src
,
indptr
,
out
,
reduce
)
else
:
out
,
arg_out
=
torch
.
ops
.
torch_scatter_cpu
.
segment_csr
(
src
,
indptr
,
out
,
reduce
)
ctx
.
save_for_backward
(
indptr
,
arg_out
)
ctx
.
save_for_backward
(
indptr
,
arg_out
)
return
out
if
arg_out
is
None
else
(
out
,
arg_out
)
return
out
if
arg_out
is
None
else
(
out
,
arg_out
)
...
@@ -106,16 +118,31 @@ class SegmentCSR(torch.autograd.Function):
...
@@ -106,16 +118,31 @@ class SegmentCSR(torch.autograd.Function):
grad_src
=
None
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
reduce
==
'sum'
or
ctx
.
reduce
==
'add'
:
if
ctx
.
reduce
==
'sum'
or
ctx
.
reduce
==
'add'
:
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_csr
(
if
grad_out
.
is_cuda
:
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
grad_src
=
torch
.
ops
.
torch_scatter_cuda
.
gather_csr
(
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
else
:
grad_src
=
torch
.
ops
.
torch_scatter_cpu
.
gather_csr
(
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
elif
ctx
.
reduce
==
'mean'
:
elif
ctx
.
reduce
==
'mean'
:
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_csr
(
if
grad_out
.
is_cuda
:
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
grad_src
=
torch
.
ops
.
torch_scatter_cuda
.
gather_csr
(
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
else
:
grad_src
=
torch
.
ops
.
torch_scatter_cpu
.
gather_csr
(
grad_out
,
indptr
,
grad_out
.
new_empty
(
src_size
))
indptr1
=
indptr
.
narrow
(
-
1
,
0
,
indptr
.
size
(
-
1
)
-
1
)
indptr1
=
indptr
.
narrow
(
-
1
,
0
,
indptr
.
size
(
-
1
)
-
1
)
indptr2
=
indptr
.
narrow
(
-
1
,
1
,
indptr
.
size
(
-
1
)
-
1
)
indptr2
=
indptr
.
narrow
(
-
1
,
1
,
indptr
.
size
(
-
1
)
-
1
)
count
=
(
indptr2
-
indptr1
).
to
(
grad_src
.
dtype
)
count
=
(
indptr2
-
indptr1
).
to
(
grad_src
.
dtype
)
count
=
gat
(
grad_out
.
is_cuda
).
gather_csr
(
if
grad_out
.
is_cuda
:
count
,
indptr
,
count
.
new_empty
(
src_size
[:
indptr
.
dim
()]))
count
=
torch
.
ops
.
torch_scatter_cuda
.
gather_csr
(
count
,
indptr
,
count
.
new_empty
(
src_size
[:
indptr
.
dim
()]))
else
:
count
=
torch
.
ops
.
torch_scatter_cpu
.
gather_csr
(
count
,
indptr
,
count
.
new_empty
(
src_size
[:
indptr
.
dim
()]))
for
_
in
range
(
grad_out
.
dim
()
-
indptr
.
dim
()):
for
_
in
range
(
grad_out
.
dim
()
-
indptr
.
dim
()):
count
=
count
.
unsqueeze
(
-
1
)
count
=
count
.
unsqueeze
(
-
1
)
grad_src
.
div_
(
count
)
grad_src
.
div_
(
count
)
...
...
torch_scatter/utils/ext.py
deleted
100644 → 0
View file @
0a221ab8
import
torch
import
torch_scatter.scatter_cpu
if
torch
.
cuda
.
is_available
():
import
torch_scatter.scatter_cuda
def
get_func
(
name
,
tensor
):
if
tensor
.
is_cuda
:
module
=
torch_scatter
.
scatter_cuda
else
:
module
=
torch_scatter
.
scatter_cpu
return
getattr
(
module
,
name
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment