Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
80a7dc52
Commit
80a7dc52
authored
Jan 12, 2020
by
rusty1s
Browse files
all tests on CPU+GPU
parent
5db00866
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
138 additions
and
23 deletions
+138
-23
benchmark/gather.py
benchmark/gather.py
+3
-1
cpu/gather.cpp
cpu/gather.cpp
+119
-4
cpu/segment.cpp
cpu/segment.cpp
+2
-2
test/test_gather.py
test/test_gather.py
+2
-9
test/test_segment.py
test/test_segment.py
+1
-4
torch_scatter/segment.py
torch_scatter/segment.py
+11
-3
No files found.
benchmark/gather.py
View file @
80a7dc52
...
@@ -7,7 +7,6 @@ from scipy.io import loadmat
...
@@ -7,7 +7,6 @@ from scipy.io import loadmat
from
torch_scatter
import
gather_coo
,
gather_csr
from
torch_scatter
import
gather_coo
,
gather_csr
from
scatter_segment
import
iters
,
sizes
from
scatter_segment
import
short_rows
,
long_rows
,
download
,
bold
from
scatter_segment
import
short_rows
,
long_rows
,
download
,
bold
...
@@ -125,6 +124,9 @@ if __name__ == '__main__':
...
@@ -125,6 +124,9 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--with_backward'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'cuda'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
iters
=
1
if
args
.
device
==
'cpu'
else
20
sizes
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
sizes
=
sizes
[:
3
]
if
args
.
device
==
'cpu'
else
sizes
for
_
in
range
(
10
):
# Warmup.
for
_
in
range
(
10
):
# Warmup.
torch
.
randn
(
100
,
100
,
device
=
args
.
device
).
sum
()
torch
.
randn
(
100
,
100
,
device
=
args
.
device
).
sum
()
...
...
cpu/gather.cpp
View file @
80a7dc52
#include <torch/extension.h>
#include <torch/extension.h>
#include "compat.h"
#include "index_info.h"
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
at
::
Tensor
gather_csr
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
at
::
Tensor
gather_csr
(
at
::
Tensor
src
,
at
::
Tensor
indptr
,
...
@@ -8,8 +11,59 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
...
@@ -8,8 +11,59 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
CHECK_CPU
(
indptr
);
CHECK_CPU
(
indptr
);
if
(
out_opt
.
has_value
())
if
(
out_opt
.
has_value
())
CHECK_CPU
(
out_opt
.
value
());
CHECK_CPU
(
out_opt
.
value
());
AT_ASSERTM
(
false
,
"Not yet implemented"
);
return
src
;
AT_ASSERTM
(
src
.
dim
()
>=
indptr
.
dim
(),
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
indptr
.
dim
()
-
1
;
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
indptr
.
size
(
i
),
"Input mismatch"
);
src
=
src
.
contiguous
();
auto
gather_dim
=
indptr
.
dim
()
-
1
;
AT_ASSERTM
(
src
.
size
(
gather_dim
)
==
indptr
.
size
(
gather_dim
)
-
1
,
"Input mismatch"
);
at
::
Tensor
out
;
if
(
out_opt
.
has_value
())
{
out
=
out_opt
.
value
().
contiguous
();
for
(
int
i
=
0
;
i
<
out
.
dim
();
i
++
)
if
(
i
!=
gather_dim
)
AT_ASSERTM
(
src
.
size
(
i
)
==
out
.
size
(
i
),
"Input mismatch"
);
}
else
{
auto
sizes
=
src
.
sizes
().
vec
();
sizes
[
gather_dim
]
=
*
indptr
.
flatten
()[
-
1
].
DATA_PTR
<
int64_t
>
();
out
=
at
::
empty
(
sizes
,
src
.
options
());
}
auto
N
=
src
.
size
(
gather_dim
)
*
(
indptr
.
numel
()
/
indptr
.
size
(
-
1
));
auto
K
=
src
.
numel
()
/
N
;
auto
E
=
out
.
size
(
gather_dim
);
auto
indptr_info
=
getTensorInfo
<
int64_t
>
(
indptr
);
auto
stride
=
indptr_info
.
strides
[
indptr_info
.
dims
-
1
];
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"gather_csr"
,
[
&
]
{
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
vals
[
K
];
int64_t
row_start
,
row_end
;
for
(
int
n
=
0
;
n
<
N
;
n
++
)
{
int
offset
=
IndexPtrToOffset
<
int64_t
>::
get
(
n
,
indptr_info
);
row_start
=
indptr_info
.
data
[
offset
];
row_end
=
indptr_info
.
data
[
offset
+
stride
];
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
vals
[
k
]
=
src_data
[
n
*
K
+
k
];
}
offset
=
(
n
/
(
indptr
.
size
(
-
1
)
-
1
))
*
E
*
K
;
for
(
int64_t
e
=
row_start
;
e
<
row_end
;
e
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
out_data
[
offset
+
e
*
K
+
k
]
=
vals
[
k
];
}
}
}
});
return
out
;
}
}
at
::
Tensor
gather_coo
(
at
::
Tensor
src
,
at
::
Tensor
index
,
at
::
Tensor
gather_coo
(
at
::
Tensor
src
,
at
::
Tensor
index
,
...
@@ -18,8 +72,69 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
...
@@ -18,8 +72,69 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
CHECK_CPU
(
index
);
CHECK_CPU
(
index
);
if
(
out_opt
.
has_value
())
if
(
out_opt
.
has_value
())
CHECK_CPU
(
out_opt
.
value
());
CHECK_CPU
(
out_opt
.
value
());
AT_ASSERTM
(
false
,
"Not yet implemented"
);
return
src
;
AT_ASSERTM
(
src
.
dim
()
>=
index
.
dim
(),
"Input mismatch"
);
for
(
int
i
=
0
;
i
<
index
.
dim
()
-
1
;
i
++
)
AT_ASSERTM
(
src
.
size
(
i
)
==
index
.
size
(
i
),
"Input mismatch"
);
src
=
src
.
contiguous
();
auto
gather_dim
=
index
.
dim
()
-
1
;
at
::
Tensor
out
;
if
(
out_opt
.
has_value
())
{
out
=
out_opt
.
value
().
contiguous
();
for
(
int
i
=
0
;
i
<
index
.
dim
();
i
++
)
AT_ASSERTM
(
out
.
size
(
i
)
==
index
.
size
(
i
),
"Input mismatch"
);
for
(
int
i
=
index
.
dim
()
+
1
;
i
<
src
.
dim
();
i
++
)
AT_ASSERTM
(
out
.
size
(
i
)
==
src
.
size
(
i
),
"Input mismatch"
);
}
else
{
auto
sizes
=
src
.
sizes
().
vec
();
sizes
[
gather_dim
]
=
index
.
size
(
gather_dim
);
out
=
at
::
empty
(
sizes
,
src
.
options
());
}
auto
E_1
=
index
.
numel
()
/
out
.
size
(
gather_dim
);
auto
E_2
=
index
.
size
(
gather_dim
);
auto
K
=
out
.
numel
()
/
index
.
numel
();
auto
N
=
src
.
size
(
gather_dim
);
auto
index_info
=
getTensorInfo
<
int64_t
>
(
index
);
auto
stride
=
index_info
.
strides
[
index_info
.
dims
-
1
];
AT_DISPATCH_ALL_TYPES
(
src
.
scalar_type
(),
"gather_coo"
,
[
&
]
{
auto
src_data
=
src
.
DATA_PTR
<
scalar_t
>
();
auto
out_data
=
out
.
DATA_PTR
<
scalar_t
>
();
scalar_t
vals
[
K
];
int64_t
idx
,
next_idx
;
for
(
int
e_1
=
0
;
e_1
<
E_1
;
e_1
++
)
{
int
offset
=
IndexToOffset
<
int64_t
>::
get
(
e_1
*
E_2
,
index_info
);
idx
=
index_info
.
data
[
offset
];
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
vals
[
k
]
=
src_data
[
e_1
*
N
*
K
+
idx
*
K
+
k
];
}
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
out_data
[
e_1
*
E_2
*
K
+
e_2
*
K
+
k
]
=
vals
[
k
];
}
if
(
e_2
<
E_2
-
1
)
{
next_idx
=
index_info
.
data
[
offset
+
(
e_2
+
1
)
*
stride
];
assert
(
idx
<=
next_idx
);
if
(
idx
!=
next_idx
)
{
idx
=
next_idx
;
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
vals
[
k
]
=
src_data
[
e_1
*
N
*
K
+
idx
*
K
+
k
];
}
}
}
}
}
});
return
out
;
}
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
...
...
cpu/segment.cpp
View file @
80a7dc52
...
@@ -184,7 +184,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -184,7 +184,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
arg_out_data
=
arg_out
.
value
().
DATA_PTR
<
int64_t
>
();
}
}
auto
E
=
index
.
numel
();
auto
E_1
=
index
.
numel
()
/
src
.
size
(
reduce_dim
);
auto
E_1
=
index
.
numel
()
/
src
.
size
(
reduce_dim
);
auto
E_2
=
src
.
size
(
reduce_dim
);
auto
E_2
=
src
.
size
(
reduce_dim
);
auto
K
=
src
.
numel
()
/
index
.
numel
();
auto
K
=
src
.
numel
()
/
index
.
numel
();
...
@@ -202,12 +201,12 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -202,12 +201,12 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
for
(
int
e_1
=
0
;
e_1
<
E_1
;
e_1
++
)
{
for
(
int
e_1
=
0
;
e_1
<
E_1
;
e_1
++
)
{
int
offset
=
IndexToOffset
<
int64_t
>::
get
(
e_1
*
E_2
,
index_info
);
int
offset
=
IndexToOffset
<
int64_t
>::
get
(
e_1
*
E_2
,
index_info
);
idx
=
index_info
.
data
[
offset
];
idx
=
index_info
.
data
[
offset
];
row_start
=
0
;
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
vals
[
k
]
=
out_data
[
e_1
*
N
*
K
+
k
];
vals
[
k
]
=
out_data
[
e_1
*
N
*
K
+
k
];
}
}
row_start
=
0
;
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
e_2
=
0
;
e_2
<
E_2
;
e_2
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
...
@@ -224,6 +223,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
...
@@ -224,6 +223,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
}
}
}
else
{
}
else
{
next_idx
=
index_info
.
data
[
offset
+
(
e_2
+
1
)
*
stride
];
next_idx
=
index_info
.
data
[
offset
+
(
e_2
+
1
)
*
stride
];
assert
(
idx
<=
next_idx
);
if
(
idx
!=
next_idx
)
{
if
(
idx
!=
next_idx
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
for
(
int
k
=
0
;
k
<
K
;
k
++
)
{
...
...
test/test_gather.py
View file @
80a7dc52
...
@@ -5,10 +5,7 @@ import torch
...
@@ -5,10 +5,7 @@ import torch
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
from
torch_scatter
import
gather_coo
,
gather_csr
from
torch_scatter
import
gather_coo
,
gather_csr
from
.utils
import
tensor
from
.utils
import
tensor
,
dtypes
,
devices
dtypes
=
[
torch
.
float
]
devices
=
[
torch
.
device
(
'cuda'
)]
tests
=
[
tests
=
[
{
{
...
@@ -50,7 +47,6 @@ tests = [
...
@@ -50,7 +47,6 @@ tests = [
]
]
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_forward
(
test
,
dtype
,
device
):
def
test_forward
(
test
,
dtype
,
device
):
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
...
@@ -65,7 +61,6 @@ def test_forward(test, dtype, device):
...
@@ -65,7 +61,6 @@ def test_forward(test, dtype, device):
assert
torch
.
all
(
out
==
expected
)
assert
torch
.
all
(
out
==
expected
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,device'
,
product
(
tests
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,device'
,
product
(
tests
,
devices
))
def
test_backward
(
test
,
device
):
def
test_backward
(
test
,
device
):
src
=
tensor
(
test
[
'src'
],
torch
.
double
,
device
)
src
=
tensor
(
test
[
'src'
],
torch
.
double
,
device
)
...
@@ -77,9 +72,8 @@ def test_backward(test, device):
...
@@ -77,9 +72,8 @@ def test_backward(test, device):
assert
gradcheck
(
gather_csr
,
(
src
,
indptr
,
None
))
is
True
assert
gradcheck
(
gather_csr
,
(
src
,
indptr
,
None
))
is
True
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_
segment
_out
(
test
,
dtype
,
device
):
def
test_
gather
_out
(
test
,
dtype
,
device
):
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
index
=
tensor
(
test
[
'index'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
indptr
=
tensor
(
test
[
'indptr'
],
torch
.
long
,
device
)
...
@@ -98,7 +92,6 @@ def test_segment_out(test, dtype, device):
...
@@ -98,7 +92,6 @@ def test_segment_out(test, dtype, device):
assert
torch
.
all
(
out
==
expected
)
assert
torch
.
all
(
out
==
expected
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
@
pytest
.
mark
.
parametrize
(
'test,dtype,device'
,
product
(
tests
,
dtypes
,
devices
))
def
test_non_contiguous_segment
(
test
,
dtype
,
device
):
def
test_non_contiguous_segment
(
test
,
dtype
,
device
):
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
src
=
tensor
(
test
[
'src'
],
dtype
,
device
)
...
...
test/test_segment.py
View file @
80a7dc52
...
@@ -5,13 +5,11 @@ import torch
...
@@ -5,13 +5,11 @@ import torch
from
torch.autograd
import
gradcheck
from
torch.autograd
import
gradcheck
from
torch_scatter
import
segment_coo
,
segment_csr
from
torch_scatter
import
segment_coo
,
segment_csr
from
.utils
import
tensor
,
dtypes
from
.utils
import
tensor
,
dtypes
,
devices
reductions
=
[
'add'
,
'mean'
,
'min'
,
'max'
]
reductions
=
[
'add'
,
'mean'
,
'min'
,
'max'
]
grad_reductions
=
[
'add'
,
'mean'
]
grad_reductions
=
[
'add'
,
'mean'
]
devices
=
[
torch
.
device
(
'cpu'
)]
tests
=
[
tests
=
[
{
{
'src'
:
[
1
,
2
,
3
,
4
,
5
,
6
],
'src'
:
[
1
,
2
,
3
,
4
,
5
,
6
],
...
@@ -105,7 +103,6 @@ def test_forward(test, reduce, dtype, device):
...
@@ -105,7 +103,6 @@ def test_forward(test, reduce, dtype, device):
assert
torch
.
all
(
out
==
expected
)
assert
torch
.
all
(
out
==
expected
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
@
pytest
.
mark
.
parametrize
(
'test,reduce,device'
,
@
pytest
.
mark
.
parametrize
(
'test,reduce,device'
,
product
(
tests
,
grad_reductions
,
devices
))
product
(
tests
,
grad_reductions
,
devices
))
def
test_backward
(
test
,
reduce
,
device
):
def
test_backward
(
test
,
reduce
,
device
):
...
...
torch_scatter/segment.py
View file @
80a7dc52
...
@@ -56,12 +56,20 @@ class SegmentCOO(torch.autograd.Function):
...
@@ -56,12 +56,20 @@ class SegmentCOO(torch.autograd.Function):
grad_src
=
None
grad_src
=
None
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
needs_input_grad
[
0
]:
if
ctx
.
reduce
==
'add'
:
if
ctx
.
reduce
==
'add'
:
grad_src
=
gat
(
grad_out
).
gather_coo
(
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
elif
ctx
.
reduce
==
'mean'
:
elif
ctx
.
reduce
==
'mean'
:
grad_src
=
gat
(
grad_out
).
gather_coo
(
grad_src
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
grad_out
,
index
,
grad_out
.
new_empty
(
src_size
))
count
=
arg_out
count
=
arg_out
# Gets pre-computed on GPU but not on CPU.
if
count
is
None
:
size
=
list
(
index
.
size
())
size
[
-
1
]
=
grad_out
.
size
(
index
.
dim
()
-
1
)
count
=
segment_cpu
.
segment_coo
(
torch
.
ones_like
(
index
,
dtype
=
grad_out
.
dtype
),
index
,
grad_out
.
new_zeros
(
size
),
'add'
)[
0
].
clamp_
(
min
=
1
)
count
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
count
=
gat
(
grad_out
.
is_cuda
).
gather_coo
(
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
count
,
index
,
count
.
new_empty
(
src_size
[:
index
.
dim
()]))
for
_
in
range
(
grad_out
.
dim
()
-
index
.
dim
()):
for
_
in
range
(
grad_out
.
dim
()
-
index
.
dim
()):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment