Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
7c46799e
Commit
7c46799e
authored
Jan 30, 2020
by
rusty1s
Browse files
doc fixes
parent
85940068
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
20 additions
and
11 deletions
+20
-11
torch_scatter/composite/logsumexp.py
torch_scatter/composite/logsumexp.py
+1
-1
torch_scatter/scatter.py
torch_scatter/scatter.py
+13
-10
torch_scatter/segment_coo.py
torch_scatter/segment_coo.py
+3
-0
torch_scatter/segment_csr.py
torch_scatter/segment_csr.py
+3
-0
No files found.
torch_scatter/composite/logsumexp.py
View file @
7c46799e
...
@@ -21,7 +21,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
...
@@ -21,7 +21,7 @@ def scatter_logsumexp(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
dim_size
=
out
.
size
(
dim
)
dim_size
=
out
.
size
(
dim
)
else
:
else
:
if
dim_size
is
None
:
if
dim_size
is
None
:
dim_size
=
int
(
index
.
max
()
.
item
(
)
+
1
)
dim_size
=
int
(
index
.
max
())
+
1
size
=
src
.
size
()
size
=
src
.
size
()
size
[
dim
]
=
dim_size
size
[
dim
]
=
dim_size
...
...
torch_scatter/scatter.py
View file @
7c46799e
...
@@ -10,20 +10,23 @@ try:
...
@@ -10,20 +10,23 @@ try:
except
OSError
:
except
OSError
:
warnings
.
warn
(
'Failed to load `scatter` binaries.'
)
warnings
.
warn
(
'Failed to load `scatter` binaries.'
)
def
placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
def
scatter_
placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
out
:
Optional
[
torch
.
Tensor
],
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
raise
ImportError
raise
ImportError
return
src
def
arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
,
def
scatter_with_arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
]
dim
:
int
,
out
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
dim_size
:
Optional
[
int
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
raise
ImportError
return
src
,
index
torch
.
ops
.
torch_scatter
.
scatter_sum
=
placeholder
torch
.
ops
.
torch_scatter
.
scatter_sum
=
scatter_
placeholder
torch
.
ops
.
torch_scatter
.
scatter_mean
=
placeholder
torch
.
ops
.
torch_scatter
.
scatter_mean
=
scatter_
placeholder
torch
.
ops
.
torch_scatter
.
scatter_min
=
arg_placeholder
torch
.
ops
.
torch_scatter
.
scatter_min
=
scatter_with_
arg_placeholder
torch
.
ops
.
torch_scatter
.
scatter_max
=
arg_placeholder
torch
.
ops
.
torch_scatter
.
scatter_max
=
scatter_with_
arg_placeholder
@
torch
.
jit
.
script
@
torch
.
jit
.
script
...
...
torch_scatter/segment_coo.py
View file @
7c46799e
...
@@ -14,16 +14,19 @@ except OSError:
...
@@ -14,16 +14,19 @@ except OSError:
out
:
Optional
[
torch
.
Tensor
],
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
dim_size
:
Optional
[
int
])
->
torch
.
Tensor
:
raise
ImportError
raise
ImportError
return
src
def
segment_coo_with_arg_placeholder
(
def
segment_coo_with_arg_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
],
out
:
Optional
[
torch
.
Tensor
],
dim_size
:
Optional
[
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
dim_size
:
Optional
[
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
raise
ImportError
return
src
,
index
def
gather_coo_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
def
gather_coo_placeholder
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
raise
ImportError
return
src
torch
.
ops
.
torch_scatter
.
segment_sum_coo
=
segment_coo_placeholder
torch
.
ops
.
torch_scatter
.
segment_sum_coo
=
segment_coo_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_coo
=
segment_coo_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_coo
=
segment_coo_placeholder
...
...
torch_scatter/segment_csr.py
View file @
7c46799e
...
@@ -13,15 +13,18 @@ except OSError:
...
@@ -13,15 +13,18 @@ except OSError:
def
segment_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
def
segment_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
raise
ImportError
return
src
def
segment_csr_with_arg_placeholder
(
def
segment_csr_with_arg_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
out
:
Optional
[
torch
.
Tensor
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
raise
ImportError
raise
ImportError
return
src
,
indptr
def
gather_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
def
gather_csr_placeholder
(
src
:
torch
.
Tensor
,
indptr
:
torch
.
Tensor
,
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
out
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
raise
ImportError
raise
ImportError
return
src
torch
.
ops
.
torch_scatter
.
segment_sum_csr
=
segment_csr_placeholder
torch
.
ops
.
torch_scatter
.
segment_sum_csr
=
segment_csr_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_csr
=
segment_csr_placeholder
torch
.
ops
.
torch_scatter
.
segment_mean_csr
=
segment_csr_placeholder
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment