Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
f60388a0
Commit
f60388a0
authored
Dec 21, 2017
by
rusty1s
Browse files
rename
parent
3f1346dc
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
20 additions
and
17 deletions
+20
-17
torch_scatter/functions/div.py
torch_scatter/functions/div.py
+2
-2
torch_scatter/functions/max.py
torch_scatter/functions/max.py
+2
-2
torch_scatter/functions/mean.py
torch_scatter/functions/mean.py
+2
-2
torch_scatter/functions/min.py
torch_scatter/functions/min.py
+2
-2
torch_scatter/functions/mul.py
torch_scatter/functions/mul.py
+2
-2
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+2
-2
torch_scatter/functions/sub.py
torch_scatter/functions/sub.py
+2
-2
torch_scatter/functions/utils.py
torch_scatter/functions/utils.py
+6
-3
No files found.
torch_scatter/functions/div.py
View file @
f60388a0
...
@@ -8,6 +8,6 @@ def scatter_div_(output, index, input, dim=0):
...
@@ -8,6 +8,6 @@ def scatter_div_(output, index, input, dim=0):
return
scatter
(
'div'
,
dim
,
output
,
index
,
input
)
return
scatter
(
'div'
,
dim
,
output
,
index
,
input
)
def
scatter_div
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
def
scatter_div
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
1
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
scatter_div_
(
output
,
index
,
input
,
dim
)
scatter_div_
(
output
,
index
,
input
,
dim
)
torch_scatter/functions/max.py
View file @
f60388a0
...
@@ -12,6 +12,6 @@ def scatter_max_(output, index, input, dim=0):
...
@@ -12,6 +12,6 @@ def scatter_max_(output, index, input, dim=0):
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
arg_output
)
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
arg_output
)
def
scatter_max
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_max
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
return
scatter_max_
(
output
,
index
,
input
,
dim
)
return
scatter_max_
(
output
,
index
,
input
,
dim
)
torch_scatter/functions/mean.py
View file @
f60388a0
...
@@ -12,6 +12,6 @@ def scatter_mean_(output, index, input, dim=0):
...
@@ -12,6 +12,6 @@ def scatter_mean_(output, index, input, dim=0):
return
output
return
output
def
scatter_mean
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_mean
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
return
scatter_mean_
(
output
,
index
,
input
,
dim
)
return
scatter_mean_
(
output
,
index
,
input
,
dim
)
torch_scatter/functions/min.py
View file @
f60388a0
...
@@ -9,6 +9,6 @@ def scatter_min_(output, index, input, dim=0):
...
@@ -9,6 +9,6 @@ def scatter_min_(output, index, input, dim=0):
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
arg_output
)
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
arg_output
)
def
scatter_min
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_min
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
return
scatter_min_
(
output
,
index
,
input
,
dim
)
return
scatter_min_
(
output
,
index
,
input
,
dim
)
torch_scatter/functions/mul.py
View file @
f60388a0
...
@@ -8,6 +8,6 @@ def scatter_mul_(output, index, input, dim=0):
...
@@ -8,6 +8,6 @@ def scatter_mul_(output, index, input, dim=0):
return
scatter
(
'mul'
,
dim
,
output
,
index
,
input
)
return
scatter
(
'mul'
,
dim
,
output
,
index
,
input
)
def
scatter_mul
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
def
scatter_mul
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
1
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
return
scatter_mul_
(
output
,
index
,
input
,
dim
)
return
scatter_mul_
(
output
,
index
,
input
,
dim
)
torch_scatter/functions/scatter.py
View file @
f60388a0
...
@@ -39,12 +39,12 @@ def _scatter(name, dim, *data):
...
@@ -39,12 +39,12 @@ def _scatter(name, dim, *data):
return
(
data
[
0
],
data
[
3
])
if
has_arg_output
(
name
)
else
data
[
0
]
return
(
data
[
0
],
data
[
3
])
if
has_arg_output
(
name
)
else
data
[
0
]
def
index_backward
(
dim
,
index
,
grad
,
arg
_grad
):
def
index_backward
(
dim
,
index
,
grad
,
arg
):
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
cuda
=
'cuda_'
if
grad
.
is_cuda
else
''
cuda
=
'cuda_'
if
grad
.
is_cuda
else
''
func
=
getattr
(
ffi
,
'index_backward_{}{}'
.
format
(
cuda
,
typename
))
func
=
getattr
(
ffi
,
'index_backward_{}{}'
.
format
(
cuda
,
typename
))
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
func
(
dim
,
output
,
index
,
grad
,
arg
_grad
)
func
(
dim
,
output
,
index
,
grad
,
arg
)
return
output
return
output
...
...
torch_scatter/functions/sub.py
View file @
f60388a0
...
@@ -7,6 +7,6 @@ def scatter_sub_(output, index, input, dim=0):
...
@@ -7,6 +7,6 @@ def scatter_sub_(output, index, input, dim=0):
return
output
.
scatter_add_
(
dim
,
index
,
-
input
)
return
output
.
scatter_add_
(
dim
,
index
,
-
input
)
def
scatter_sub
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_sub
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
size
,
fill_value
)
return
scatter_sub_
(
output
,
index
,
input
,
dim
)
return
scatter_sub_
(
output
,
index
,
input
,
dim
)
torch_scatter/functions/utils.py
View file @
f60388a0
...
@@ -9,8 +9,11 @@ def gen_filled_tensor(input, size, fill_value):
...
@@ -9,8 +9,11 @@ def gen_filled_tensor(input, size, fill_value):
return
Variable
(
input
.
data
.
new
(
size
).
fill_
(
fill_value
))
return
Variable
(
input
.
data
.
new
(
size
).
fill_
(
fill_value
))
def
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
):
def
gen_output
(
index
,
input
,
dim
,
dim_size
,
fill_value
):
max_index
=
index
.
max
()
+
1
if
max_index
is
None
else
max_index
if
dim_size
is
None
:
dim_size
=
index
.
max
()
+
1
dim_size
=
dim
.
size
if
torch
.
is_tensor
(
input
)
else
dim_size
.
data
[
0
]
size
=
list
(
index
.
size
())
size
=
list
(
index
.
size
())
size
[
dim
]
=
max_index
if
torch
.
is_tensor
(
input
)
else
max_index
.
data
[
0
]
size
[
dim
]
=
dim_size
return
gen_filled_tensor
(
input
,
torch
.
Size
(
size
),
fill_value
)
return
gen_filled_tensor
(
input
,
torch
.
Size
(
size
),
fill_value
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment