OpenDAS / torch-scatter · Commits

Commit 8a1f0741, authored Feb 11, 2020 by rusty1s

update

Parent: eb7da512

Showing 8 changed files with 284 additions and 305 deletions (+284 -305)
.travis.yml                +1    -2
csrc/cpu/scatter_cpu.cpp   +65   -67
csrc/scatter.cpp           +206  -18
csrc/scatter.cpp2          +0    -216
csrc/segment_coo.cpp       +5    -0
csrc/segment_csr.cpp       +5    -0
setup.py                   +1    -1
torch_scatter/__init__.py  +1    -1
.travis.yml  (view file @ 8a1f0741)

...
@@ -3,7 +3,6 @@ sudo: enabled
 env:
   global:
     - TORCH_VERSION=1.4.0
-    - CIBW_BUILD=cp36-*
 # jobs:
 #   - FORCE_CUDA=0 TORCH=${TORCH_VERSION}+cpu
 #   - FORCE_CUDA=1 CUDA_SHORT=9.2 CUDA=9.2.148-1 UBUNTU_VERSION=ubuntu1604 CUBLAS=cuda-cublas-dev-9-2 TORCH=${TORCH_VERSION}+cu92
...
@@ -110,7 +109,7 @@ install:
 script:
   - flake8 .
   - python3 setup.py test || python setup.py install
-  - python3 setup.py sdist bdist_wheel
+  - python3 setup.py bdist_wheel || python3 setup.py bdist_wheel
   - ls dist

 notifications:
   email: false
csrc/cpu/scatter_cpu.cpp  (view file @ 8a1f0741)

 #include "scatter_cpu.h"

-// #include "index_info.h"
-// #include "reducer.h"
-// #include "utils.h"
+#include "index_info.h"
+#include "reducer.h"
+#include "utils.h"

 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
             torch::optional<torch::Tensor> optional_out,
             torch::optional<int64_t> dim_size, std::string reduce) {
   return std::make_tuple(src, optional_out);
-  // CHECK_CPU(src);
-  // CHECK_CPU(index);
-  // if (optional_out.has_value())
-  //   CHECK_CPU(optional_out.value());
+  CHECK_CPU(src);
+  CHECK_CPU(index);
+  if (optional_out.has_value())
+    CHECK_CPU(optional_out.value());

-  // CHECK_INPUT(src.dim() == index.dim());
-  // for (auto i = 0; i < index.dim() - 1; i++)
-  //   CHECK_INPUT(src.size(i) >= index.size(i));
+  CHECK_INPUT(src.dim() == index.dim());
+  for (auto i = 0; i < index.dim() - 1; i++)
+    CHECK_INPUT(src.size(i) >= index.size(i));

-  // src = src.contiguous();
+  src = src.contiguous();

-  // torch::Tensor out;
-  // if (optional_out.has_value()) {
-  //   out = optional_out.value().contiguous();
-  //   for (auto i = 0; i < out.dim(); i++)
-  //     if (i != dim)
-  //       CHECK_INPUT(src.size(i) == out.size(i));
-  // } else {
-  //   auto sizes = src.sizes().vec();
-  //   if (dim_size.has_value())
-  //     sizes[dim] = dim_size.value();
-  //   else if (index.numel() == 0)
-  //     sizes[dim] = 0;
-  //   else
-  //     sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
-  //   out = torch::empty(sizes, src.options());
-  // }
+  torch::Tensor out;
+  if (optional_out.has_value()) {
+    out = optional_out.value().contiguous();
+    for (auto i = 0; i < out.dim(); i++)
+      if (i != dim)
+        CHECK_INPUT(src.size(i) == out.size(i));
+  } else {
+    auto sizes = src.sizes().vec();
+    if (dim_size.has_value())
+      sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
+    else
+      sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
+    out = torch::empty(sizes, src.options());
+  }

-  // torch::optional<torch::Tensor> arg_out = torch::nullopt;
-  // int64_t *arg_out_data = nullptr;
-  // if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-  //   arg_out = torch::full_like(out, src.size(dim), index.options());
-  //   arg_out_data = arg_out.value().data_ptr<int64_t>();
-  // }
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
+  int64_t *arg_out_data = nullptr;
+  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
+    arg_out = torch::full_like(out, src.size(dim), index.options());
+    arg_out_data = arg_out.value().data_ptr<int64_t>();
+  }

-  // if (index.numel() == 0)
-  //   return std::make_tuple(out, arg_out);
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);

-  // auto B = 1;
-  // for (auto i = 0; i < dim; i++)
-  //   B *= src.size(i);
-  // auto E = src.size(dim);
-  // auto K = src.numel() / (B * E);
-  // auto N = out.size(dim);
+  auto B = 1;
+  for (auto i = 0; i < dim; i++)
+    B *= src.size(i);
+  auto E = src.size(dim);
+  auto K = src.numel() / (B * E);
+  auto N = out.size(dim);

-  // auto index_info = getTensorInfo<int64_t>(index);
-  // AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
-  //   auto src_data = src.data_ptr<scalar_t>();
-  //   auto out_data = out.data_ptr<scalar_t>();
+  auto index_info = getTensorInfo<int64_t>(index);
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
+    auto src_data = src.data_ptr<scalar_t>();
+    auto out_data = out.data_ptr<scalar_t>();

-  //   int64_t i, idx;
-  //   AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
-  //     if (!optional_out.has_value())
-  //       out.fill_(Reducer<scalar_t>::init(REDUCE));
+    int64_t i, idx;
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      if (!optional_out.has_value())
+        out.fill_(Reducer<scalar_t>::init(REDUCE));

-  //     for (auto b = 0; b < B; b++) {
-  //       for (auto e = 0; e < E; e++) {
-  //         for (auto k = 0; k < K; k++) {
-  //           i = b * E * K + e * K + k;
-  //           idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
-  //           Reducer<scalar_t>::update(
-  //               REDUCE, out_data + b * N * K + idx * K + k, src_data[i],
-  //               arg_out_data + b * N * K + idx * K + k, e);
-  //         }
-  //       }
-  //     }
+      for (auto b = 0; b < B; b++) {
+        for (auto e = 0; e < E; e++) {
+          for (auto k = 0; k < K; k++) {
+            i = b * E * K + e * K + k;
+            idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
+            Reducer<scalar_t>::update(
+                REDUCE, out_data + b * N * K + idx * K + k, src_data[i],
+                arg_out_data + b * N * K + idx * K + k, e);
+          }
+        }
+      }

-  //     if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
-  //       out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE),
-  //                        (scalar_t)0);
-  //   });
-  // });
+      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
+        out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0);
+    });
+  });

-  // return std::make_tuple(out, arg_out);
+  return std::make_tuple(out, arg_out);
 }
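
The kernel above flattens `src` into a row-major (B, E, K) view: B is the product of the sizes before `dim`, E = src.size(dim), and K is the product of the sizes after it, while `out` is indexed as (B, N, K). A minimal Python sketch of that index arithmetic, illustrative only and not part of the commit:

    import torch

    src = torch.arange(2 * 3 * 4).view(2, 3, 4)  # scatter along dim = 1
    B, E, K = 2, 3, 4                            # sizes before, at, and after dim

    flat = src.contiguous().view(-1)
    for b in range(B):
        for e in range(E):
            for k in range(K):
                i = b * E * K + e * K + k        # the kernel's linear source index
                assert flat[i] == src[b, e, k]   # matches the row-major layout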
csrc/scatter.cpp  (view file @ 8a1f0741)

 #include <Python.h>
 #include <torch/script.h>

-// #include "cpu/scatter_cpu.h"
-// #include "utils.h"
+#include "cpu/scatter_cpu.h"
+#include "utils.h"

-// #ifdef WITH_CUDA
-// #include <cuda.h>
-// #include "cuda/scatter_cuda.h"
-// #endif
+#ifdef WITH_CUDA
+#include "cuda/scatter_cuda.h"
+#endif

+#ifdef _WIN32
+PyMODINIT_FUNC PyInit__scatter(void) { return NULL; }
+#endif
+torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
+  if (src.dim() == 1)
+    for (auto i = 0; i < dim; i++)
+      src = src.unsqueeze(0);
+  for (auto i = src.dim(); i < other.dim(); i++)
+    src = src.unsqueeze(-1);
+  src = src.expand(other.sizes().vec());
+  return src;
+}
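
A line-for-line Python sketch of `broadcast`, illustrative only: it left-pads a 1-D `src` with leading singleton dimensions up to position `dim`, right-pads with trailing ones, and expands to `other`'s shape so it can be used as an element-wise index tensor.

    import torch

    def broadcast(src, other, dim):
        if src.dim() == 1:
            for _ in range(dim):
                src = src.unsqueeze(0)
        for _ in range(src.dim(), other.dim()):
            src = src.unsqueeze(-1)
        return src.expand(other.size())

    index = torch.tensor([0, 1, 1, 2])
    src = torch.randn(4, 8)
    print(broadcast(index, src, 0).shape)  # torch.Size([4, 8])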
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size, std::string reduce) {
   return std::make_tuple(src, optional_out);
-  // if (src.device().is_cuda()) {
-  // #ifdef WITH_CUDA
-  //   return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
-  // #else
-  //   AT_ERROR("Not compiled with CUDA support");
-  // #endif
-  // } else {
-  //   return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
-  // }
+  if (src.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
+#else
+    AT_ERROR("Not compiled with CUDA support");
+#endif
+  } else {
+    return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
+  }
 }
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+class ScatterSum : public torch::autograd::Function<ScatterSum> {
+public:
+  static variable_list forward(AutogradContext *ctx, Variable src,
+                               Variable index, int64_t dim,
+                               torch::optional<Variable> optional_out,
+                               torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
+    ctx->saved_data["dim"] = dim;
+    ctx->saved_data["src_shape"] = src.sizes();
+    index = broadcast(index, src, dim);
+    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
+    auto out = std::get<0>(result);
+    ctx->save_for_backward({index});
+    if (optional_out.has_value())
+      ctx->mark_dirty({optional_out.value()});
+    return {out};
+  }
+
+  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto grad_out = grad_outs[0];
+    auto saved = ctx->get_saved_variables();
+    auto index = saved[0];
+    auto dim = ctx->saved_data["dim"].toInt();
+    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
+    auto grad_in = torch::gather(grad_out, dim, index, false);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
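
ScatterSum::backward routes each incoming gradient back through the same index with a gather. A quick check with stock PyTorch, illustrative only and not part of the commit:

    import torch

    index = torch.tensor([0, 1, 0])
    src = torch.randn(3, requires_grad=True)
    out = torch.zeros(2).scatter_add(0, index, src)  # out[0] = src[0] + src[2]
    out.backward(torch.tensor([10.0, 20.0]))
    # Every src element receives the gradient of the slot it was summed into,
    # i.e. grad_in = gather(grad_out, 0, index):
    assert torch.equal(src.grad, torch.tensor([10.0, 20.0, 10.0]))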
+class ScatterMean : public torch::autograd::Function<ScatterMean> {
+public:
+  static variable_list forward(AutogradContext *ctx, Variable src,
+                               Variable index, int64_t dim,
+                               torch::optional<Variable> optional_out,
+                               torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
+    ctx->saved_data["dim"] = dim;
+    ctx->saved_data["src_shape"] = src.sizes();
+    auto old_index = index;
+    index = broadcast(index, src, dim);
+    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
+    auto out = std::get<0>(result);
+    auto ones = torch::ones(old_index.sizes(), src.options());
+    result = scatter_fw(ones, old_index,
+                        old_index.dim() <= dim ? old_index.dim() - 1 : dim,
+                        torch::nullopt, out.size(dim), "sum");
+    auto count = std::get<0>(result);
+    count.clamp_(1);
+    count = broadcast(count, out, dim);
+    out.div_(count);
+    ctx->save_for_backward({index, count});
+    if (optional_out.has_value())
+      ctx->mark_dirty({optional_out.value()});
+    return {out};
+  }
+
+  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto grad_out = grad_outs[0];
+    auto saved = ctx->get_saved_variables();
+    auto index = saved[0];
+    auto count = saved[1];
+    auto dim = ctx->saved_data["dim"].toInt();
+    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
+    count = torch::gather(count, dim, index, false);
+    auto grad_in = torch::gather(grad_out, dim, index, false);
+    grad_in.div_(count);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
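
ScatterMean reuses the sum kernel, then divides by per-slot counts obtained by scattering ones; the clamp to 1 guards against division by zero for output slots that received nothing. A small Python sketch of the same idea, illustrative only:

    import torch

    index = torch.tensor([0, 0, 2])
    src = torch.tensor([1.0, 3.0, 5.0])
    out = torch.zeros(3).scatter_add(0, index, src)              # [4., 0., 5.]
    count = torch.zeros(3).scatter_add(0, index, torch.ones(3))  # [2., 0., 1.]
    count.clamp_(min=1)        # empty slot 1 would otherwise divide by zero
    print(out / count)         # tensor([2., 0., 5.])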
+class ScatterMin : public torch::autograd::Function<ScatterMin> {
+public:
+  static variable_list forward(AutogradContext *ctx, Variable src,
+                               Variable index, int64_t dim,
+                               torch::optional<Variable> optional_out,
+                               torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
+    ctx->saved_data["dim"] = dim;
+    ctx->saved_data["src_shape"] = src.sizes();
+    index = broadcast(index, src, dim);
+    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
+    auto out = std::get<0>(result);
+    auto arg_out = std::get<1>(result).value();
+    ctx->save_for_backward({index, arg_out});
+    ctx->mark_non_differentiable({arg_out});
+    if (optional_out.has_value())
+      ctx->mark_dirty({optional_out.value()});
+    return {out, arg_out};
+  }
+
+  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto grad_out = grad_outs[0];
+    auto saved = ctx->get_saved_variables();
+    auto index = saved[0];
+    auto arg_out = saved[1];
+    auto dim = ctx->saved_data["dim"].toInt();
+    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
+    src_shape[dim] += 1;
+    auto grad_in = torch::zeros(src_shape, grad_out.options());
+    grad_in.scatter_(dim, arg_out, grad_out);
+    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
+class ScatterMax : public torch::autograd::Function<ScatterMax> {
+public:
+  static variable_list forward(AutogradContext *ctx, Variable src,
+                               Variable index, int64_t dim,
+                               torch::optional<Variable> optional_out,
+                               torch::optional<int64_t> dim_size) {
+    dim = dim < 0 ? src.dim() + dim : dim;
+    ctx->saved_data["dim"] = dim;
+    ctx->saved_data["src_shape"] = src.sizes();
+    index = broadcast(index, src, dim);
+    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
+    auto out = std::get<0>(result);
+    auto arg_out = std::get<1>(result).value();
+    ctx->save_for_backward({index, arg_out});
+    ctx->mark_non_differentiable({arg_out});
+    if (optional_out.has_value())
+      ctx->mark_dirty({optional_out.value()});
+    return {out, arg_out};
+  }
+
+  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
+    auto grad_out = grad_outs[0];
+    auto saved = ctx->get_saved_variables();
+    auto index = saved[0];
+    auto arg_out = saved[1];
+    auto dim = ctx->saved_data["dim"].toInt();
+    auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
+    src_shape[dim] += 1;
+    auto grad_in = torch::zeros(src_shape, grad_out.options());
+    grad_in.scatter_(dim, arg_out, grad_out);
+    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
+    return {grad_in, Variable(), Variable(), Variable(), Variable()};
+  }
+};
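
In both ScatterMin and ScatterMax, the CPU kernel initialises arg_out with src.size(dim) as a sentinel for "no contribution", so backward pads src_shape by one extra row along dim, scatters grad_out through arg_out, and narrows the pad away again. A small Python sketch of that trick, illustrative only:

    import torch

    src_len, dim = 3, 0
    arg_out = torch.tensor([1, src_len])       # slot 0 won by src[1]; slot 1 empty
    grad_out = torch.tensor([7.0, 9.0])
    grad_in = torch.zeros(src_len + 1)         # src_shape[dim] += 1
    grad_in.scatter_(dim, arg_out, grad_out)   # sentinel gradient lands in the pad
    grad_in = grad_in.narrow(dim, 0, src_len)  # tensor([0., 7., 0.])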
+torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
+                          torch::optional<torch::Tensor> optional_out,
+                          torch::optional<int64_t> dim_size) {
+  return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
+}
+
+torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
+                           torch::optional<torch::Tensor> optional_out,
+                           torch::optional<int64_t> dim_size) {
+  return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
+}
+
+std::tuple<torch::Tensor, torch::Tensor>
+scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
+            torch::optional<torch::Tensor> optional_out,
+            torch::optional<int64_t> dim_size) {
+  auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
+  return std::make_tuple(result[0], result[1]);
+}
+
+std::tuple<torch::Tensor, torch::Tensor>
+scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
+            torch::optional<torch::Tensor> optional_out,
+            torch::optional<int64_t> dim_size) {
+  auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
+  return std::make_tuple(result[0], result[1]);
+}
-static auto registry =
-    torch::RegisterOperators().op("torch_scatter::scatter_fw", &scatter_fw);
+static auto registry = torch::RegisterOperators()
+                           .op("torch_scatter::scatter_sum", &scatter_sum)
+                           .op("torch_scatter::scatter_mean", &scatter_mean)
+                           .op("torch_scatter::scatter_min", &scatter_min)
+                           .op("torch_scatter::scatter_max", &scatter_max);
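
Ops registered through torch::RegisterOperators surface under the torch.ops namespace once the compiled extension is loaded. A hypothetical usage sketch, not from this commit; the library path is an assumption:

    import torch

    torch.ops.load_library("torch_scatter/_scatter.so")  # hypothetical path
    src = torch.randn(4, 8)
    index = torch.tensor([0, 1, 1, 2])
    out = torch.ops.torch_scatter.scatter_sum(src, index, 0, None, None)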
csrc/scatter.cpp2  deleted 100644 → 0  (view file @ eb7da512)
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/scatter_cuda.h"
#endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1)
for (auto i = 0; i < dim; i++)
src = src.unsqueeze(0);
for (auto i = src.dim(); i < other.dim(); i++)
src = src.unsqueeze(-1);
src = src.expand(other.sizes().vec());
return src;
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out, dim, index, false);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
auto old_index = index;
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
auto ones = torch::ones(old_index.sizes(), src.options());
result = scatter_fw(ones, old_index,
old_index.dim() <= dim ? old_index.dim() - 1 : dim,
torch::nullopt, out.size(dim), "sum");
auto count = std::get<0>(result);
count.clamp_(1);
count = broadcast(count, out, dim);
out.div_(count);
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.div_(count);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators()
.op("torch_scatter::scatter_sum", &scatter_sum)
.op("torch_scatter::scatter_mean", &scatter_mean)
.op("torch_scatter::scatter_min", &scatter_min)
.op("torch_scatter::scatter_max", &scatter_max);
csrc/segment_coo.cpp2 → csrc/segment_coo.cpp  (view file @ 8a1f0741)

+#include <Python.h>
 #include <torch/script.h>

 #include "cpu/segment_coo_cpu.h"
...
@@ -7,6 +8,10 @@
 #include "cuda/segment_coo_cuda.h"
 #endif

+#ifdef _WIN32
+PyMODINIT_FUNC PyInit__scatter(void) { return NULL; }
+#endif
+
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 segment_coo_fw(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out, ...
...
csrc/segment_csr.cpp2 → csrc/segment_csr.cpp  (view file @ 8a1f0741)

+#include <Python.h>
 #include <torch/script.h>

 #include "cpu/segment_csr_cpu.h"
...
@@ -7,6 +8,10 @@
 #include "cuda/segment_csr_cuda.h"
 #endif

+#ifdef _WIN32
+PyMODINIT_FUNC PyInit__scatter(void) { return NULL; }
+#endif
+
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out, ...
...
setup.py  (view file @ 8a1f0741)

...
@@ -57,7 +57,7 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_scatter',
-    version='2.0.2',
+    version='2.0.3',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url='https://github.com/rusty1s/pytorch_scatter',
...
torch_scatter/__init__.py  (view file @ 8a1f0741)

...
@@ -9,7 +9,7 @@ from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
 from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
                         scatter_log_softmax)

-__version__ = '2.0.2'
+__version__ = '2.0.3'

 __all__ = [
     'scatter_sum',
...