Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
416a2603
Commit
416a2603
authored
Dec 17, 2017
by
rusty1s
Browse files
backward index impl
parent
880c8102
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
84 additions
and
52 deletions
+84
-52
test/test_max.py
test/test_max.py
+12
-3
torch_scatter/functions/__init__.py
torch_scatter/functions/__init__.py
+6
-12
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+28
-5
torch_scatter/functions/utils.py
torch_scatter/functions/utils.py
+1
-1
torch_scatter/src/cpu.c
torch_scatter/src/cpu.c
+1
-0
torch_scatter/src/cpu.h
torch_scatter/src/cpu.h
+8
-0
torch_scatter/src/generic/cpu.c
torch_scatter/src/generic/cpu.c
+28
-31
No files found.
test/test_max.py
View file @
416a2603
...
...
@@ -27,10 +27,19 @@ def test_scatter_mean(str):
output
=
Variable
(
output
).
fill_
(
0
)
index
=
Variable
(
index
)
input
=
Variable
(
input
,
requires_grad
=
True
)
_
,
output_index
=
scatter_max_
(
output
,
index
,
input
,
dim
=
1
)
scatter_max_
(
output
,
index
,
input
,
dim
=
1
)
grad_output
=
[[
0
,
1
,
2
,
3
,
4
,
5
]
,
[
0
,
1
,
2
,
3
,
4
,
5
]]
grad_output
=
[[
1
0
,
2
0
,
3
0
,
4
0
,
5
0
,
60
],
[
15
,
2
5
,
3
5
,
4
5
,
55
,
6
5
]]
grad_output
=
Tensor
(
str
,
grad_output
)
output
.
backward
(
grad_output
)
assert
index
.
data
.
tolist
()
==
input
.
grad
.
data
.
tolist
()
# assert index.data.tolist() == input.grad.data.tolist()
# output = Variable(torch.FloatTensor([0, 0, 0, 0, 0]))
index
=
Variable
(
torch
.
LongTensor
([
3
,
4
,
4
,
2
,
1
]))
input
=
Variable
(
torch
.
FloatTensor
([
1
,
2
,
3
,
4
,
5
]),
requires_grad
=
True
)
output
,
output_index
=
scatter_max
(
index
,
input
)
# print(output_index)
output
.
backward
(
torch
.
FloatTensor
([
10
,
20
,
30
,
40
]))
print
(
input
.
grad
)
torch_scatter/functions/__init__.py
View file @
416a2603
...
...
@@ -6,8 +6,7 @@ from .utils import gen_output
def
scatter_add_
(
output
,
index
,
input
,
dim
=
0
):
scatter
(
'add'
,
dim
,
output
,
index
,
input
)
return
output
return
scatter
(
'add'
,
dim
,
output
,
index
,
input
)
def
scatter_add
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
...
@@ -16,8 +15,7 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
def
scatter_sub_
(
output
,
index
,
input
,
dim
=
0
):
scatter
(
'sub'
,
dim
,
output
,
index
,
input
)
return
output
return
scatter
(
'sub'
,
dim
,
output
,
index
,
input
)
def
scatter_sub
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
...
@@ -26,8 +24,7 @@ def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
def
scatter_mul_
(
output
,
index
,
input
,
dim
=
0
):
scatter
(
'mul'
,
dim
,
output
,
index
,
input
)
return
output
return
scatter
(
'mul'
,
dim
,
output
,
index
,
input
)
def
scatter_mul
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
...
...
@@ -36,8 +33,7 @@ def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
def
scatter_div_
(
output
,
index
,
input
,
dim
=
0
):
scatter
(
'div'
,
dim
,
output
,
index
,
input
)
return
output
return
scatter
(
'div'
,
dim
,
output
,
index
,
input
)
def
scatter_div
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
...
...
@@ -66,8 +62,7 @@ def scatter_max_(output, index, input, dim=0):
output_index
=
index
.
new
(
output
.
size
()).
fill_
(
-
1
)
else
:
output_index
=
Variable
(
index
.
data
.
new
(
output
.
size
()).
fill_
(
-
1
))
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
output_index
)
return
output
,
output_index
return
scatter
(
'max'
,
dim
,
output
,
index
,
input
,
output_index
)
def
scatter_max
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
...
@@ -80,8 +75,7 @@ def scatter_min_(output, index, input, dim=0):
output_index
=
index
.
new
(
output
.
size
()).
fill_
(
-
1
)
else
:
output_index
=
Variable
(
index
.
data
.
new
(
output
.
size
()).
fill_
(
-
1
))
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
output_index
)
return
output
,
output_index
return
scatter
(
'min'
,
dim
,
output
,
index
,
input
,
output_index
)
def
scatter_min
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
...
torch_scatter/functions/scatter.py
View file @
416a2603
...
...
@@ -6,6 +6,10 @@ from torch.autograd import Function
from
.._ext
import
ffi
def
_has_output_index
(
name
):
return
name
in
[
'max'
,
'min'
]
def
_scatter
(
name
,
dim
,
*
data
):
a
,
b
,
c
=
data
[:
3
]
...
...
@@ -31,6 +35,15 @@ def _scatter(name, dim, *data):
typename
=
type
(
data
[
0
]).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
ffi
,
'scatter_{}_{}'
.
format
(
name
,
typename
))
func
(
dim
,
*
data
)
return
(
data
[
0
],
data
[
3
])
if
_has_output_index
(
name
)
else
data
[
0
]
def
_index_backward
(
dim
,
index
,
grad
,
grad_index
):
typename
=
type
(
grad
).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
ffi
,
'index_backward_{}'
.
format
(
typename
))
output
=
grad
.
new
(
index
.
size
()).
fill_
(
0
)
func
(
dim
,
output
,
index
,
grad
,
grad_index
)
return
output
class
_Scatter
(
Function
):
...
...
@@ -44,21 +57,31 @@ class _Scatter(Function):
self
.
mark_dirty
(
data
[
0
])
# Mark output as dirty.
self
.
len
=
len
(
data
)
# Save number of arguments for backward step
self
.
save_for_backward
(
data
[
1
])
# Save index for backward step.
_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
return
data
[
0
]
if
_has_output_index
(
self
.
name
):
self
.
save_for_backward
(
data
[
1
],
data
[
3
])
return
data
[
0
],
data
[
3
]
else
:
self
.
save_for_backward
(
data
[
1
])
return
data
[
0
]
def
backward
(
self
,
*
data
):
index
,
=
self
.
saved_variables
grad_output
=
grad_input
=
None
if
self
.
needs_input_grad
[
0
]:
grad_output
=
data
[
0
]
if
self
.
needs_input_grad
[
2
]:
# TODO: max and min
if
self
.
needs_input_grad
[
2
]
and
not
_has_output_index
(
self
.
name
):
index
,
=
self
.
saved_variables
grad_input
=
data
[
0
].
gather
(
self
.
dim
,
index
.
data
)
if
self
.
needs_input_grad
[
2
]
and
_has_output_index
(
self
.
name
):
index
,
grad_index
=
self
.
saved_variables
data
=
(
index
.
data
,
data
[
0
],
grad_index
.
data
)
grad_input
=
_index_backward
(
self
.
dim
,
*
data
)
return
(
grad_output
,
None
,
grad_input
)
+
(
None
,
)
*
(
self
.
len
-
3
)
...
...
torch_scatter/functions/utils.py
View file @
416a2603
...
...
@@ -11,4 +11,4 @@ def gen_output(index, input, dim, max_index, fill_value):
return
input
.
new
(
torch
.
Size
(
size
)).
fill_
(
fill_value
)
else
:
size
[
dim
]
=
max_index
.
data
[
0
]
return
Variable
(
input
.
new
(
torch
.
Size
(
size
)).
fill_
(
fill_value
))
return
Variable
(
input
.
data
.
new
(
torch
.
Size
(
size
)).
fill_
(
fill_value
))
torch_scatter/src/cpu.c
View file @
416a2603
...
...
@@ -3,6 +3,7 @@
#include "THTensorDimApply4.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
#define index_backward TH_CONCAT_2(index_backward_, Real)
inline
void
assertIndexInBoundaries
(
int
idx
,
int
size
,
int64_t
*
free
)
{
if
(
idx
<
0
||
idx
>=
size
)
{
THFree
(
free
);
THError
(
"Invalid index"
);
}
...
...
torch_scatter/src/cpu.h
View file @
416a2603
...
...
@@ -53,3 +53,11 @@ void scatter_min_Char (int dim, THCharTensor *output, THLongTensor *index, TH
void
scatter_min_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THLongTensor
*
output_index
);
void
scatter_min_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THLongTensor
*
output_index
);
void
scatter_min_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
output_index
);
void
index_backward_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
grad
,
THLongTensor
*
grad_index
);
void
index_backward_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
grad
,
THLongTensor
*
grad_index
);
void
index_backward_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
grad
,
THLongTensor
*
grad_index
);
void
index_backward_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
grad
,
THLongTensor
*
grad_index
);
void
index_backward_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
grad
,
THLongTensor
*
grad_index
);
void
index_backward_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
grad
,
THLongTensor
*
grad_index
);
void
index_backward_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
grad
,
THLongTensor
*
grad_index
);
torch_scatter/src/generic/cpu.c
View file @
416a2603
...
...
@@ -3,75 +3,72 @@
#else
void
scatter_
(
add
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
]
+=
*
(
input_data
+
i
*
input_stride
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
+=
input_data
[
i
];
})
}
void
scatter_
(
sub
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
]
-=
*
(
input_data
+
i
*
input_stride
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
-=
input_data
[
i
];
})
}
void
scatter_
(
mul
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
]
*=
*
(
input_data
+
i
*
input_stride
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
*=
input_data
[
i
];
})
}
void
scatter_
(
div
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
]
/=
*
(
input_data
+
i
*
input_stride
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
/=
input_data
[
i
];
})
}
void
scatter_
(
mean
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THTensor
*
output_count
)
{
int64_t
idx
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
real
,
output_count
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
]
+=
*
(
input_data
+
i
*
input_stride
);
output_count_data
[
idx
]
++
;
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
+=
input_data
[
i
];
output_count_data
[
index_data
[
i
]]
++
;
})
}
void
scatter_
(
max
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
output_index
)
{
int64_t
idx
;
real
old
,
new
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
output_index
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
old
=
output_data
[
idx
];
new
=
*
(
input_data
+
i
*
input_stride
);
if
(
new
>=
old
)
{
output_data
[
idx
]
=
new
;
output_index_data
[
idx
]
=
i
;
}
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
if
(
input_data
[
i
]
>=
output_data
[
index_data
[
i
]])
{
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
output_index_data
[
index_data
[
i
]]
=
i
;
}
})
}
void
scatter_
(
min
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
output_index
)
{
int64_t
idx
;
real
old
,
new
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
output_index
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
old
=
output_data
[
idx
];
new
=
*
(
input_data
+
i
*
input_stride
);
if
(
new
<=
old
)
{
output_data
[
idx
]
=
new
;
output_index_data
[
idx
]
=
i
;
}
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
if
(
input_data
[
i
]
<=
output_data
[
index_data
[
i
]])
{
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
output_index_data
[
index_data
[
i
]]
=
i
;
}
})
}
void
index_backward
(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
grad
,
THLongTensor
*
grad_index
)
{
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
grad
,
int64_t
,
grad_index
,
dim
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
if
(
grad_index_data
[
index_data
[
i
]]
==
i
)
output_data
[
index_data
[
i
]]
=
grad_data
[
i
];
})
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment