OpenDAS / torch-scatter · Commit 48024c15
authored Jan 11, 2020 by rusty1s
Parent: 5817fb9d

added cpu segment implementation

Showing 5 changed files with 287 additions and 13 deletions (+287 / -13)
benchmark/scatter_segment.py   +3    -3
cpu/index_info.h               +65   -0
cpu/segment.cpp                +216  -4
cuda/segment_kernel.cu         +2    -2
test/test_segment.py           +1    -4
benchmark/scatter_segment.py (+3 / -3)

@@ -11,9 +11,6 @@ import torch_scatter
 from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
 from torch_scatter import segment_coo, segment_csr

-iters = 20
-sizes = [1, 16, 32, 64, 128, 256, 512]
-
 short_rows = [
     ('DIMACS10', 'citationCiteseer'),
     ('SNAP', 'web-Stanford'),

@@ -216,6 +213,9 @@ if __name__ == '__main__':
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
     args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
+    iters = 1 if args.device == 'cpu' else 20
+    sizes = [1, 16, 32, 64, 128, 256, 512]
+    sizes = sizes[:3] if args.device == 'cpu' else sizes

     for _ in range(10):  # Warmup.
         torch.randn(100, 100, device=args.device).sum()
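For orientation, a minimal CPU-only timing sketch along the lines of this benchmark (not part of the commit; it assumes the positional segment_csr(src, indptr, out, reduce) call used in test/test_segment.py below, and mirrors the iters = 1 / truncated sizes setting that the change introduces for CPU runs):

import time

import torch
from torch_scatter import segment_csr

src = torch.randn(10000, 32)                  # CPU tensors, as with --device cpu
indptr = torch.arange(0, 10001, 100)          # 101 boundaries -> 100 equal segments

start = time.perf_counter()
out = segment_csr(src, indptr, None, 'add')   # single iteration, matching iters = 1 on CPU
print('segment_csr add took', time.perf_counter() - start, 'seconds')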
cpu/index_info.h (new file, mode 100644, +65)

#pragma once

#include <torch/extension.h>

#include "compat.h"

#define MAX_TENSORINFO_DIMS 25

template <typename scalar_t> struct TensorInfo {
  TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS],
             int st[MAX_TENSORINFO_DIMS]) {
    data = p;
    dims = dim;
    AT_ASSERT(dims < MAX_TENSORINFO_DIMS);

    for (int i = 0; i < dim; ++i) {
      sizes[i] = sz[i];
      strides[i] = st[i];
    }
  }

  scalar_t *data;
  int dims;
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];
};

template <typename scalar_t>
TensorInfo<scalar_t> getTensorInfo(const at::Tensor &tensor) {
  int sizes[MAX_TENSORINFO_DIMS];
  int strides[MAX_TENSORINFO_DIMS];

  int dims = tensor.dim();
  for (int i = 0; i < dims; ++i) {
    sizes[i] = tensor.size(i);
    strides[i] = tensor.stride(i);
  }

  return TensorInfo<scalar_t>(tensor.DATA_PTR<scalar_t>(), dims, sizes, strides);
}

template <typename scalar_t> struct IndexToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int offset = 0;
    for (int i = info.dims - 1; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};

template <typename scalar_t> struct IndexPtrToOffset {
  static inline int get(int idx, const TensorInfo<scalar_t> &info) {
    int offset = idx % (info.sizes[info.dims - 1] - 1);
    offset *= info.strides[info.dims - 1];
    idx /= info.sizes[info.dims - 1] - 1;
    for (int i = info.dims - 2; i >= 0; --i) {
      offset += (idx % info.sizes[i]) * info.strides[i];
      idx /= info.sizes[i];
    }
    return offset;
  }
};
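To make the pointer-index arithmetic concrete, here is a small pure-Python mirror of IndexPtrToOffset::get (an illustrative sketch, assuming the same sizes/strides layout as TensorInfo): the last dimension of indptr holds sizes[-1] boundaries, i.e. sizes[-1] - 1 segments, so the flat segment counter is unwrapped with that reduced size before the remaining dimensions are handled as in IndexToOffset.

def index_ptr_to_offset(idx, sizes, strides):
    # Mirror of IndexPtrToOffset::get in cpu/index_info.h (illustrative only).
    offset = (idx % (sizes[-1] - 1)) * strides[-1]
    idx //= sizes[-1] - 1
    for i in range(len(sizes) - 2, -1, -1):
        offset += (idx % sizes[i]) * strides[i]
        idx //= sizes[i]
    return offset

# Example: indptr of shape (2, 4) -> 3 segments per row, row stride 4.
print(index_ptr_to_offset(4, sizes=[2, 4], strides=[4, 1]))  # 5: start boundary of row 1, segment 1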
cpu/segment.cpp (+216 / -4)

 #include <torch/extension.h>

+#include "compat.h"
+#include "index_info.h"

 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

+enum ReductionType { ADD, MEAN, MIN, MAX };
+
+#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
+  [&] { \
+    if (reduce == "add") { \
+      const ReductionType REDUCE = ADD; \
+      return __VA_ARGS__(); \
+    } else if (reduce == "mean") { \
+      const ReductionType REDUCE = MEAN; \
+      return __VA_ARGS__(); \
+    } else if (reduce == "min") { \
+      const ReductionType REDUCE = MIN; \
+      return __VA_ARGS__(); \
+    } else if (reduce == "max") { \
+      const ReductionType REDUCE = MAX; \
+      return __VA_ARGS__(); \
+    } \
+  }()
+
+template <typename scalar_t, ReductionType REDUCE> struct Reducer {
+  static inline scalar_t init() {
+    if (REDUCE == MIN) {
+      return std::numeric_limits<scalar_t>::max();
+    } else if (REDUCE == MAX) {
+      return std::numeric_limits<scalar_t>::lowest();
+    } else {
+      return (scalar_t)0;
+    }
+  }
+
+  static inline void update(scalar_t *val, scalar_t new_val) {
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+    }
+  }
+
+  static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg,
+                            int64_t new_arg) {
+    if (REDUCE == ADD || REDUCE == MEAN) {
+      *val = *val + new_val;
+    } else if ((REDUCE == MIN && new_val < *val) ||
+               (REDUCE == MAX && new_val > *val)) {
+      *val = new_val;
+      *arg = new_arg;
+    }
+  }
+
+  static inline void write(scalar_t *address, scalar_t val,
+                           int64_t *arg_address, int64_t arg, int count) {
+    if (REDUCE == ADD) {
+      *address = val;
+    } else if (REDUCE == MEAN) {
+      *address = val / (count > 0 ? count : (scalar_t)1);
+    } else if (REDUCE == MIN || REDUCE == MAX) {
+      if (count > 0) {
+        *address = val;
+        *arg_address = arg;
+      } else {
+        *address = (scalar_t)0;
+      }
+    }
+  }
+};
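As a rough Python analogue of what Reducer and AT_DISPATCH_REDUCTION_TYPES encode (a sketch, not part of the commit): each reduction has an identity value, a pairwise update, and a final write that normalises 'mean' by the segment length and leaves empty 'min'/'max' segments at zero.

import math

# Illustrative mirror of Reducer<scalar_t, REDUCE> for float inputs.
INIT = {'add': 0.0, 'mean': 0.0, 'min': math.inf, 'max': -math.inf}
UPDATE = {'add': lambda a, b: a + b, 'mean': lambda a, b: a + b,
          'min': min, 'max': max}

def write(reduce, val, count):
    if reduce == 'mean':
        return val / (count if count > 0 else 1)
    if reduce in ('min', 'max') and count == 0:
        return 0.0          # empty segments are written as zero
    return val

def reduce_segment(reduce, values):
    val = INIT[reduce]
    for v in values:
        val = UPDATE[reduce](val, v)
    return write(reduce, val, len(values))

print(reduce_segment('mean', [1.0, 2.0, 3.0]))  # 2.0
print(reduce_segment('max', []))                # 0.0 (empty segment)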
 std::tuple<at::Tensor, at::optional<at::Tensor>>
 segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
             std::string reduce) {

@@ -9,8 +79,74 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   CHECK_CPU(indptr);
   if (out_opt.has_value())
     CHECK_CPU(out_opt.value());

-  AT_ASSERTM(false, "Not yet implemented");
-  return std::make_tuple(src, at::nullopt);
+  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
+
+  // Broadcasting `indptr` via `expand`.
+  auto sizes = indptr.sizes().vec();
+  for (int i = 0; i < indptr.dim() - 1; i++) {
+    sizes[i] = src.size(i);
+  }
+  indptr = indptr.expand(sizes);
+
+  src = src.contiguous();
+  auto reduce_dim = indptr.dim() - 1;
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != reduce_dim)
+        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
+    AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
+               "Input mismatch");
+  } else {
+    sizes = src.sizes().vec();
+    sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
+    out = at::empty(sizes, src.options());
+  }
+
+  at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce == "min" || reduce == "max") {
+    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+  }
+
+  auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
+  auto K = out.numel() / N;
+  auto E = src.size(reduce_dim);
+
+  auto indptr_info = getTensorInfo<int64_t>(indptr);
+  auto stride = indptr_info.strides[indptr_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_csr", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t val;
+    int64_t row_start, row_end, arg;
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      for (int n = 0; n < N; n++) {
+        int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
+        row_start = indptr_info.data[offset];
+        row_end = indptr_info.data[offset + stride];
+
+        offset = (n / (indptr.size(-1) - 1)) * E * K;
+        for (int k = 0; k < K; k++) {
+          val = Reducer<scalar_t, REDUCE>::init();
+
+          for (int64_t e = row_start; e < row_end; e++) {
+            Reducer<scalar_t, REDUCE>::update(
+                &val, src_data[offset + e * K + k], &arg, e);
+          }
+
+          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, val,
+                                           arg_out_data + n * K + k, arg,
+                                           row_end - row_start);
+        }
+      }
+    });
+  });
+
+  return std::make_tuple(out, arg_out);
 }
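In plain Python terms, for a 1-D input the kernel computes one reduction per consecutive indptr window. A sketch of those semantics (not the commit's code):

def segment_csr_1d(src, indptr, reduce='add'):
    # Reference semantics of the CPU kernel for 1-D src and 1-D indptr:
    # out[i] reduces src[indptr[i]:indptr[i + 1]]; empty segments follow
    # Reducer::write (0 for add/min/max, mean divides by max(count, 1)).
    out = []
    for i in range(len(indptr) - 1):
        seg = src[indptr[i]:indptr[i + 1]]
        if reduce == 'add':
            out.append(sum(seg))
        elif reduce == 'mean':
            out.append(sum(seg) / (len(seg) if len(seg) > 0 else 1))
        elif reduce == 'min':
            out.append(min(seg) if seg else 0)
        elif reduce == 'max':
            out.append(max(seg) if seg else 0)
    return out

print(segment_csr_1d([1, 2, 3, 4, 5], [0, 2, 2, 5], 'add'))  # [3, 0, 12]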
 std::tuple<at::Tensor, at::optional<at::Tensor>>

@@ -19,8 +155,84 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);

-  AT_ASSERTM(false, "Not yet implemented");
-  return std::make_tuple(src, at::nullopt);
+  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
+
+  // Broadcasting `index` via `expand`.
+  auto sizes = index.sizes().vec();
+  for (int i = 0; i < index.dim(); i++) {
+    sizes[i] = src.size(i);
+  }
+  index = index.expand(sizes);
+
+  src = src.contiguous();
+  out = out.contiguous();
+  auto reduce_dim = index.dim() - 1;
+
+  for (int i = 0; i < out.dim(); i++)
+    if (i != reduce_dim)
+      AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
+
+  at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce == "min" || reduce == "max") {
+    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
+  }
+
+  auto E_1 = index.numel() / src.size(reduce_dim);
+  auto E_2 = src.size(reduce_dim);
+  auto K = src.numel() / index.numel();
+  auto N = out.size(reduce_dim);
+
+  auto index_info = getTensorInfo<int64_t>(index);
+  auto stride = index_info.strides[index_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t val;
+    int64_t idx, next_idx, row_start, arg;
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      for (int e_1 = 0; e_1 < E_1; e_1++) {
+        int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
+        for (int k = 0; k < K; k++) {
+          idx = index_info.data[offset];
+          row_start = 0;
+          val = out_data[e_1 * N * K + k];
+
+          for (int e_2 = 0; e_2 < E_2; e_2++) {
+            Reducer<scalar_t, REDUCE>::update(
+                &val, src_data[e_1 * E_2 * K + e_2 * K + k], &arg, e_2);
+
+            if (e_2 == E_2 - 1) {
+              Reducer<scalar_t, REDUCE>::write(
+                  out_data + e_1 * N * K + idx * K + k, val,
+                  arg_out_data + e_1 * N * K + idx * K + k, arg,
+                  e_2 + 1 - row_start);
+            } else {
+              next_idx = index_info.data[offset + (e_2 + 1) * stride];
+              if (idx != next_idx) {
+                Reducer<scalar_t, REDUCE>::write(
+                    out_data + e_1 * N * K + idx * K + k, val,
+                    arg_out_data + e_1 * N * K + idx * K + k, arg,
+                    e_2 + 1 - row_start);
+
+                row_start = e_2 + 1;
+                val = out_data[e_1 * N * K + next_idx * K + k];
+              }
+              idx = next_idx;
+            }
+          }
+        }
+      }
+    });
+  });
+
+  return std::make_tuple(out, arg_out);
 }
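The COO variant walks the (assumed sorted) index along the reduce dimension and flushes each run of equal indices into out, picking up the value already stored there. A rough 1-D Python sketch of those semantics ('add' only, illustrative rather than the commit's code):

def segment_coo_add_1d(src, index, out):
    # index is assumed sorted; each run of equal indices is accumulated and
    # the total is added on top of whatever out already holds at that index.
    idx, val = index[0], out[index[0]]
    for e in range(len(src)):
        val += src[e]
        if e == len(src) - 1 or index[e + 1] != idx:
            out[idx] = val
            if e != len(src) - 1:
                idx = index[e + 1]
                val = out[idx]
    return out

print(segment_coo_add_1d([1, 2, 3, 4], [0, 0, 1, 2], [0, 0, 10]))  # [3, 3, 14]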
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
 ...
cuda/segment_kernel.cu (+2 / -2)

@@ -178,7 +178,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
   AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");

-  // Broadcasting across `index` via `expand`.
+  // Broadcasting `indptr` via `expand`.
   auto sizes = indptr.sizes().vec();
   for (int i = 0; i < indptr.dim() - 1; i++) {
     sizes[i] = src.size(i);

@@ -379,7 +379,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
   AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");

-  // Broadcasting across `index` via `expand`.
+  // Broadcasting `index` via `expand`.
   auto sizes = index.sizes().vec();
   for (int i = 0; i < index.dim(); i++) {
     sizes[i] = src.size(i);
test/test_segment.py (+1 / -4)

@@ -10,7 +10,7 @@ from .utils import tensor, dtypes

 reductions = ['add', 'mean', 'min', 'max']
 grad_reductions = ['add', 'mean']
-devices = [torch.device('cuda')]
+devices = [torch.device('cpu')]

 tests = [
     {

@@ -82,7 +82,6 @@ tests = [
 ]


-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
 def test_forward(test, reduce, dtype, device):

@@ -119,7 +118,6 @@ def test_backward(test, reduce, device):
     assert gradcheck(segment_csr, (src, indptr, None, reduce)) is True


-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
 def test_segment_out(test, reduce, dtype, device):

@@ -153,7 +151,6 @@ def test_segment_out(test, reduce, dtype, device):
     assert torch.all(out == expected)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
 def test_non_contiguous_segment(test, reduce, dtype, device):
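With the default test device switched to CPU, a single parametrised case boils down to something like the following (a hedged sketch; the real tests build src, indptr and expected from the tests fixtures and also exercise segment_coo, min/max and non-contiguous inputs, and it assumes the 'add' path returns the reduced tensor directly):

import torch
from torch_scatter import segment_csr

src = torch.tensor([1., 2., 3., 4., 5.])
indptr = torch.tensor([0, 2, 5])

out = segment_csr(src, indptr, None, 'add')   # positional call, as in the gradcheck above
expected = torch.tensor([3., 12.])
assert torch.all(out == expected)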