Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
a7beacab
Commit
a7beacab
authored
Jan 12, 2018
by
rusty1s
Browse files
bugfix with tensor strides
parent
aeca7758
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
103 additions
and
29 deletions
+103
-29
setup.py
setup.py
+1
-1
test/backward.json
test/backward.json
+27
-0
test/forward.json
test/forward.json
+48
-6
torch_scatter/__init__.py
torch_scatter/__init__.py
+1
-1
torch_scatter/src/generic/cpu.c
torch_scatter/src/generic/cpu.c
+26
-21
No files found.
setup.py
View file @
a7beacab
...
@@ -4,7 +4,7 @@ from setuptools import setup, find_packages
...
@@ -4,7 +4,7 @@ from setuptools import setup, find_packages
import
build
# noqa
import
build
# noqa
__version__
=
'0.
1.3
'
__version__
=
'0.
2.0
'
url
=
'https://github.com/rusty1s/pytorch_scatter'
url
=
'https://github.com/rusty1s/pytorch_scatter'
install_requires
=
[
'cffi'
]
install_requires
=
[
'cffi'
]
...
...
test/backward.json
View file @
a7beacab
...
@@ -8,6 +8,15 @@
...
@@ -8,6 +8,15 @@
"grad"
:
[[
10
,
20
,
30
,
40
,
50
,
60
],
[
15
,
25
,
35
,
45
,
55
,
65
]],
"grad"
:
[[
10
,
20
,
30
,
40
,
50
,
60
],
[
15
,
25
,
35
,
45
,
55
,
65
]],
"expected"
:
[[
50
,
60
,
50
,
30
,
40
],
[
15
,
15
,
35
,
35
,
25
]]
"expected"
:
[[
50
,
60
,
50
,
30
,
40
],
[
15
,
15
,
35
,
35
,
25
]]
},
},
{
"name"
:
"add"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
0
,
"grad"
:
[[
10
,
20
],
[
15
,
25
]],
"expected"
:
[[
10
,
20
],
[
15
,
25
],
[
15
,
25
],
[
10
,
20
]]
},
{
{
"name"
:
"mean"
,
"name"
:
"mean"
,
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
...
@@ -17,6 +26,15 @@
...
@@ -17,6 +26,15 @@
"grad"
:
[[
10
,
20
,
30
,
40
,
50
,
60
],
[
15
,
25
,
35
,
45
,
55
,
65
]],
"grad"
:
[[
10
,
20
,
30
,
40
,
50
,
60
],
[
15
,
25
,
35
,
45
,
55
,
65
]],
"expected"
:
[[
50
,
60
,
50
,
30
,
40
],
[
15
,
15
,
35
,
35
,
25
]]
"expected"
:
[[
50
,
60
,
50
,
30
,
40
],
[
15
,
15
,
35
,
35
,
25
]]
},
},
{
"name"
:
"mean"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
0
,
"grad"
:
[[
10
,
20
],
[
15
,
25
]],
"expected"
:
[[
10
,
20
],
[
15
,
25
],
[
15
,
25
],
[
10
,
20
]]
},
{
{
"name"
:
"max"
,
"name"
:
"max"
,
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
...
@@ -25,5 +43,14 @@
...
@@ -25,5 +43,14 @@
"fill_value"
:
0
,
"fill_value"
:
0
,
"grad"
:
[[
10
,
20
,
30
,
40
,
50
,
60
],
[
15
,
25
,
35
,
45
,
55
,
65
]],
"grad"
:
[[
10
,
20
,
30
,
40
,
50
,
60
],
[
15
,
25
,
35
,
45
,
55
,
65
]],
"expected"
:
[[
50
,
60
,
0
,
30
,
40
],
[
0
,
15
,
0
,
35
,
25
]]
"expected"
:
[[
50
,
60
,
0
,
30
,
40
],
[
0
,
15
,
0
,
35
,
25
]]
},
{
"name"
:
"max"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
0
,
"grad"
:
[[
10
,
20
],
[
15
,
25
]],
"expected"
:
[[
10
,
0
],
[
0
,
25
],
[
15
,
0
],
[
0
,
20
]]
}
}
]
]
test/forward.json
View file @
a7beacab
...
@@ -7,6 +7,14 @@
...
@@ -7,6 +7,14 @@
"fill_value"
:
0
,
"fill_value"
:
0
,
"expected"
:
[[
0
,
0
,
4
,
3
,
3
,
0
],
[
2
,
4
,
4
,
0
,
0
,
0
]]
"expected"
:
[[
0
,
0
,
4
,
3
,
3
,
0
],
[
2
,
4
,
4
,
0
,
0
,
0
]]
},
},
{
"name"
:
"add"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
0
,
"expected"
:
[[
6
,
5
],
[
6
,
8
]]
},
{
{
"name"
:
"sub"
,
"name"
:
"sub"
,
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
...
@@ -15,6 +23,14 @@
...
@@ -15,6 +23,14 @@
"fill_value"
:
9
,
"fill_value"
:
9
,
"expected"
:
[[
9
,
9
,
5
,
6
,
6
,
9
],
[
7
,
5
,
5
,
9
,
9
,
9
]]
"expected"
:
[[
9
,
9
,
5
,
6
,
6
,
9
],
[
7
,
5
,
5
,
9
,
9
,
9
]]
},
},
{
"name"
:
"sub"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
9
,
"expected"
:
[[
3
,
4
],
[
3
,
1
]]
},
{
{
"name"
:
"mul"
,
"name"
:
"mul"
,
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
...
@@ -23,6 +39,14 @@
...
@@ -23,6 +39,14 @@
"fill_value"
:
1
,
"fill_value"
:
1
,
"expected"
:
[[
1
,
1
,
4
,
3
,
2
,
0
],
[
0
,
4
,
3
,
1
,
1
,
1
]]
"expected"
:
[[
1
,
1
,
4
,
3
,
2
,
0
],
[
0
,
4
,
3
,
1
,
1
,
1
]]
},
},
{
"name"
:
"mul"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
1
,
"expected"
:
[[
5
,
6
],
[
8
,
15
]]
},
{
{
"name"
:
"div"
,
"name"
:
"div"
,
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
...
@@ -32,12 +56,12 @@
...
@@ -32,12 +56,12 @@
"expected"
:
[[
1
,
1
,
0.25
,
0.5
,
0.5
,
1
],
[
0.5
,
0.25
,
0.5
,
1
,
1
,
1
]]
"expected"
:
[[
1
,
1
,
0.25
,
0.5
,
0.5
,
1
],
[
0.5
,
0.25
,
0.5
,
1
,
1
,
1
]]
},
},
{
{
"name"
:
"
mean
"
,
"name"
:
"
div
"
,
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]],
"input"
:
[[
4
,
2
],
[
2
,
1
],
[
4
,
2
]
,
[
1
,
2
]],
"dim"
:
1
,
"dim"
:
0
,
"fill_value"
:
0
,
"fill_value"
:
1
,
"expected"
:
[[
0
,
0
,
4
,
3
,
1
.5
,
0
],
[
1
,
4
,
2
,
0
,
0
,
0
]]
"expected"
:
[[
0.
2
5
,
0
.25
],
[
0.125
,
0
.5
]]
},
},
{
{
"name"
:
"max"
,
"name"
:
"max"
,
...
@@ -48,6 +72,15 @@
...
@@ -48,6 +72,15 @@
"expected"
:
[[
0
,
0
,
4
,
3
,
2
,
0
],
[
2
,
4
,
3
,
0
,
0
,
0
]],
"expected"
:
[[
0
,
0
,
4
,
3
,
2
,
0
],
[
2
,
4
,
3
,
0
,
0
,
0
]],
"expected_arg"
:
[[
-1
,
-1
,
3
,
4
,
0
,
1
],
[
1
,
4
,
3
,
-1
,
-1
,
-1
]]
"expected_arg"
:
[[
-1
,
-1
,
3
,
4
,
0
,
1
],
[
1
,
4
,
3
,
-1
,
-1
,
-1
]]
},
},
{
"name"
:
"max"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
0
,
"expected"
:
[[
5
,
3
],
[
4
,
5
]],
"expected_arg"
:
[[
0
,
3
],
[
2
,
1
]]
},
{
{
"name"
:
"min"
,
"name"
:
"min"
,
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
"index"
:
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]],
...
@@ -56,5 +89,14 @@
...
@@ -56,5 +89,14 @@
"fill_value"
:
9
,
"fill_value"
:
9
,
"expected"
:
[[
9
,
9
,
4
,
3
,
1
,
0
],
[
0
,
4
,
1
,
9
,
9
,
9
]],
"expected"
:
[[
9
,
9
,
4
,
3
,
1
,
0
],
[
0
,
4
,
1
,
9
,
9
,
9
]],
"expected_arg"
:
[[
-1
,
-1
,
3
,
4
,
2
,
1
],
[
0
,
4
,
2
,
-1
,
-1
,
-1
]]
"expected_arg"
:
[[
-1
,
-1
,
3
,
4
,
2
,
1
],
[
0
,
4
,
2
,
-1
,
-1
,
-1
]]
},
{
"name"
:
"min"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"dim"
:
0
,
"fill_value"
:
9
,
"expected"
:
[[
1
,
2
],
[
2
,
3
]],
"expected_arg"
:
[[
3
,
0
],
[
1
,
2
]]
}
}
]
]
torch_scatter/__init__.py
View file @
a7beacab
...
@@ -6,7 +6,7 @@ from .functions.mean import scatter_mean_, scatter_mean
...
@@ -6,7 +6,7 @@ from .functions.mean import scatter_mean_, scatter_mean
from
.functions.max
import
scatter_max_
,
scatter_max
from
.functions.max
import
scatter_max_
,
scatter_max
from
.functions.min
import
scatter_min_
,
scatter_min
from
.functions.min
import
scatter_min_
,
scatter_min
__version__
=
'0.
1.3
'
__version__
=
'0.
2.0
'
__all__
=
[
__all__
=
[
'scatter_add_'
,
'scatter_add'
,
'scatter_sub_'
,
'scatter_sub'
,
'scatter_add_'
,
'scatter_add'
,
'scatter_sub_'
,
'scatter_sub'
,
...
...
torch_scatter/src/generic/cpu.c
View file @
a7beacab
...
@@ -3,62 +3,67 @@
...
@@ -3,62 +3,67 @@
#else
#else
void
scatter_
(
mul
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
void
scatter_
(
mul
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
i
;
int64_t
i
,
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
idx
=
*
(
index_data
+
i
*
index_stride
);
output_data
[
index_data
[
i
]]
*=
input_data
[
i
];
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
*
output_stride
]
*=
*
(
input_data
+
i
*
input_stride
);
})
})
}
}
void
scatter_
(
div
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
void
scatter_
(
div
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
i
;
int64_t
i
,
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
idx
=
*
(
index_data
+
i
*
index_stride
);
output_data
[
index_data
[
i
]]
/=
input_data
[
i
];
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
*
output_stride
]
/=
*
(
input_data
+
i
*
input_stride
);
})
})
}
}
void
scatter_
(
mean
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THTensor
*
count
)
{
void
scatter_
(
mean
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THTensor
*
count
)
{
int64_t
i
;
int64_t
i
,
idx
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
real
,
count
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
real
,
count
,
dim
,
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
i
ndex_data
[
i
]
]
+=
input_data
[
i
]
;
output_data
[
i
dx
*
output_stride
]
+=
*
(
input_data
+
i
*
input_stride
)
;
c
ou
n
t_data
[
i
ndex_data
[
i
]
]
++
;
ou
tpu
t_data
[
i
dx
*
count_stride
]
++
;
})
})
}
}
void
scatter_
(
max
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
arg
)
{
void
scatter_
(
max
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
arg
)
{
int64_t
i
;
int64_t
i
,
idx
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
arg
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
arg
,
dim
,
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
idx
=
*
(
index_data
+
i
*
index_stride
);
if
(
input_data
[
i
]
>=
output_data
[
index_data
[
i
]])
{
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
if
(
*
(
input_data
+
i
*
input_stride
)
>=
*
(
output_data
+
idx
*
output_stride
))
{
arg_data
[
index_data
[
i
]]
=
i
;
output_data
[
idx
*
output_stride
]
=
*
(
input_data
+
i
*
input_stride
);
arg_data
[
idx
*
arg_stride
]
=
i
;
}
}
})
})
}
}
void
scatter_
(
min
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
arg
)
{
void
scatter_
(
min
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THLongTensor
*
arg
)
{
int64_t
i
;
int64_t
i
,
idx
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
arg
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
int64_t
,
arg
,
dim
,
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
assertIndexInBoundaries
(
index_data
[
i
],
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
idx
=
*
(
index_data
+
i
*
index_stride
);
if
(
input_data
[
i
]
<=
output_data
[
index_data
[
i
]])
{
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
index_data
[
i
]]
=
input_data
[
i
];
if
(
*
(
input_data
+
i
*
input_stride
)
<=
*
(
output_data
+
idx
*
output_stride
))
{
arg_data
[
index_data
[
i
]]
=
i
;
output_data
[
idx
*
output_stride
]
=
*
(
input_data
+
i
*
input_stride
);
arg_data
[
idx
*
arg_stride
]
=
i
;
}
}
})
})
}
}
void
index_backward
(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
grad
,
THLongTensor
*
arg
)
{
void
index_backward
(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
grad
,
THLongTensor
*
arg
)
{
int64_t
i
;
int64_t
i
,
idx
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
grad
,
int64_t
,
arg
,
dim
,
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
grad
,
int64_t
,
arg
,
dim
,
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
if
(
arg_data
[
index_data
[
i
]]
==
i
)
output_data
[
i
]
=
grad_data
[
index_data
[
i
]];
idx
=
*
(
index_data
+
i
*
index_stride
);
if
(
*
(
arg_data
+
idx
*
arg_stride
)
==
i
)
output_data
[
i
*
output_stride
]
=
*
(
grad_data
+
idx
*
grad_stride
);
})
})
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment