OpenDAS / torch-scatter · Commit 62815576
Authored Jan 29, 2020 by rusty1s
Parent: 0a221ab8

    moved extensions to torch.ops

Showing 20 of 22 changed files, with 200 additions and 175 deletions (+200 −175):
cpu/dim_apply.h            +2   −2
cpu/gather.cpp             +13  −14
cpu/index_info.h           +1   −1
cpu/scatter.cpp            +13  −14
cpu/segment.cpp            +16  −17
cuda/gather.cpp            +14  −14
cuda/gather_kernel.cu      +11  −8
cuda/index.cuh             +1   −1
cuda/indptr.cuh            +1   −1
cuda/scatter.cpp           +22  −22
cuda/scatter_kernel.cu     +9   −7
cuda/segment.cpp           +16  −16
cuda/segment_kernel.cu     +17  −13
setup.py                   +6   −3
torch_scatter/__init__.py  +14  −0
torch_scatter/div.py       +6   −6
torch_scatter/gather.py    +20  −19
torch_scatter/max.py       +6   −6
torch_scatter/min.py       +6   −5
torch_scatter/mul.py       +6   −6
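Taken together, the diff replaces per-extension pybind11 bindings with operators registered through torch::RegisterOperators, which PyTorch exposes under torch.ops once the compiled library is loaded. A minimal sketch of the resulting calling convention, assuming the .so paths from torch_scatter/__init__.py below and illustrative tensor shapes:

    import torch

    # Load the compiled extension once; its operators then appear under the
    # namespace chosen at registration time ("torch_scatter_cpu::...").
    torch.ops.load_library('torch_scatter/scatter_cpu.so')

    src = torch.tensor([2., 3., 4., 5.])
    index = torch.tensor([0, 0, 1, 1])
    out = torch.ones(2)

    # Ops registered in cpu/scatter.cpp are resolved dynamically; scatter_mul
    # performs out[index[i]] *= src[i] in place along dim 0.
    torch.ops.torch_scatter_cpu.scatter_mul(src, index, out, 0)
    print(out)  # tensor([ 6., 20.])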
cpu/dim_apply.h

@@ -19,7 +19,7 @@
   auto TENSOR3##_stride = TENSOR3.stride(DIM);                             \
                                                                            \
   auto dims = TENSOR1.dim();                                               \
-  auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong));        \
+  auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong));  \
   auto counter = zeros.DATA_PTR<int64_t>();                                \
   bool has_finished = false;                                               \
@@ -78,7 +78,7 @@
   auto TENSOR4##_stride = TENSOR4.stride(DIM);                             \
                                                                            \
   auto dims = TENSOR1.dim();                                               \
-  auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong));        \
+  auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong));  \
   auto counter = zeros.DATA_PTR<int64_t>();                                \
   bool has_finished = false;                                               \
cpu/gather.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

 #include "compat.h"
 #include "index_info.h"

 #include <vector>

-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")

-at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CPU(src);
   CHECK_CPU(indptr);
   if (out_opt.has_value())
@@ -23,7 +23,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
              "Input mismatch");

-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -32,7 +32,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }

   auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
@@ -68,8 +68,8 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   return out;
 }

-at::Tensor gather_coo(at::Tensor src, at::Tensor index,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   if (out_opt.has_value())
@@ -82,7 +82,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   src = src.contiguous();
   auto gather_dim = index.dim() - 1;

-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < index.dim(); i++)
@@ -92,7 +92,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = index.size(gather_dim);
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }

   auto E_1 = index.numel() / out.size(gather_dim);
@@ -139,7 +139,6 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   return out;
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("gather_csr", &gather_csr, "Gather CSR (CPU)");
-  m.def("gather_coo", &gather_coo, "Gather COO (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::gather_csr", &gather_csr)
+        .op("torch_scatter_cpu::gather_coo", &gather_coo);
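For context, the CSR shape contract asserted above means src.size(gather_dim) must equal indptr.size(gather_dim) - 1. A hedged sketch of calling the newly registered op from Python (library path and namespace as in this diff; the shapes are illustrative, and a preallocated tensor is passed for the optional out argument):

    import torch

    torch.ops.load_library('torch_scatter/gather_cpu.so')

    src = torch.randn(3, 4)               # one row per CSR segment
    indptr = torch.tensor([0, 4, 7, 10])  # 3 segments covering 10 output rows
    out = src.new_empty(10, 4)            # row count = indptr[-1]

    # gather_csr repeats src[i] across positions indptr[i]..indptr[i+1]-1.
    out = torch.ops.torch_scatter_cpu.gather_csr(src, indptr, out)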
cpu/index_info.h

@@ -26,7 +26,7 @@ template <typename scalar_t> struct TensorInfo {
 };

 template <typename scalar_t>
-TensorInfo<scalar_t> getTensorInfo(const at::Tensor &tensor) {
+TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
   int sizes[MAX_TENSORINFO_DIMS];
   int strides[MAX_TENSORINFO_DIMS];
cpu/scatter.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

 #include "dim_apply.h"

-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")

-void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -20,7 +20,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }

-void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -36,8 +36,8 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }

-void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);
@@ -56,8 +56,8 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }

-void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);
@@ -77,9 +77,8 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)");
-  m.def("scatter_div", &scatter_div, "Scatter Div (CPU)");
-  m.def("scatter_max", &scatter_max, "Scatter Max (CPU)");
-  m.def("scatter_min", &scatter_min, "Scatter Min (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::scatter_mul", &scatter_mul)
+        .op("torch_scatter_cpu::scatter_div", &scatter_div)
+        .op("torch_scatter_cpu::scatter_max", &scatter_max)
+        .op("torch_scatter_cpu::scatter_min", &scatter_min);
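Note that the min/max variants write both the reduced values and the winning source positions into caller-allocated buffers, mirroring how torch_scatter/max.py invokes them further down in this diff. A hedged sketch (illustrative shapes; out must be pre-filled with the reduction's identity or existing values):

    import torch

    torch.ops.load_library('torch_scatter/scatter_cpu.so')

    src = torch.tensor([1., 5., 2., 7.])
    index = torch.tensor([0, 0, 1, 1])
    out = torch.full((2,), float('-inf'))  # identity element for max
    arg = index.new_full((2,), -1)         # argmax slots, as in max.py

    torch.ops.torch_scatter_cpu.scatter_max(src, index, out, arg, 0)
    # out == [5., 7.], arg == [1, 3]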
cpu/segment.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

 #include "compat.h"
 #include "index_info.h"

 #include <vector>

-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")

 enum ReductionType { SUM, MEAN, MIN, MAX };
@@ -74,9 +74,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   }
 };

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
-            std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr(torch::Tensor src, torch::Tensor indptr,
+            torch::optional<torch::Tensor> out_opt, std::string reduce) {
   CHECK_CPU(src);
   CHECK_CPU(indptr);
   if (out_opt.has_value())
@@ -94,7 +94,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   src = src.contiguous();
   auto reduce_dim = indptr.dim() - 1;

-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -105,13 +105,13 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   } else {
     sizes = src.sizes().vec();
     sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }

-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -156,8 +156,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   return std::make_tuple(out, arg_out);
 }

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
             std::string reduce) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -180,10 +180,10 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     if (i != reduce_dim)
       AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");

-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -251,7 +251,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
   return std::make_tuple(out, arg_out);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("segment_csr", &segment_csr, "Segment CSR (CPU)");
-  m.def("segment_coo", &segment_coo, "Segment COO (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr)
+        .op("torch_scatter_cpu::segment_coo", &segment_coo);
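The segment ops return an (out, arg_out) tuple, with arg_out populated only for 'min'/'max' reductions and left as torch::nullopt otherwise. A hedged sketch of the Python-side view of that contract (illustrative shapes; this assumes the inferred schema accepts None for the optional out argument, as the wrappers elsewhere pass real tensors):

    import torch

    torch.ops.load_library('torch_scatter/segment_cpu.so')

    src = torch.tensor([1., 5., 2., 7., 3.])
    indptr = torch.tensor([0, 2, 5])  # two segments: [0:2) and [2:5)

    # 'max' fills arg_out with the winning positions along the reduced dim;
    # for 'sum'/'mean' the second tuple element stays unset.
    out, arg_out = torch.ops.torch_scatter_cpu.segment_csr(src, indptr, None, 'max')
    # expected: out == [5., 7.], arg_out == [1, 3]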
cuda/gather.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+#define CHECK_CUDA(x)                                                         \
+  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")

-at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
-                           at::optional<at::Tensor> out_opt);
-at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
-                           at::optional<at::Tensor> out_opt);
+torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                              torch::optional<torch::Tensor> out_opt);
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> out_opt);

-at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CUDA(src);
   CHECK_CUDA(indptr);
   if (out_opt.has_value())
@@ -16,8 +17,8 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   return gather_csr_cuda(src, indptr, out_opt);
 }

-at::Tensor gather_coo(at::Tensor src, at::Tensor index,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   if (out_opt.has_value())
@@ -25,7 +26,6 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   return gather_coo_cuda(src, index, out_opt);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("gather_csr", &gather_csr, "Gather CSR (CUDA)");
-  m.def("gather_coo", &gather_coo, "Gather COO (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cuda::gather_csr", &gather_csr)
+        .op("torch_scatter_cuda::gather_coo", &gather_coo);
cuda/gather_kernel.cu

 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <torch/extension.h>

 #include "compat.cuh"
 #include "indptr.cuh"
@@ -58,9 +58,10 @@ __global__ void gather_csr_broadcast_kernel(
   }
 }

-at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
-                           at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                              torch::optional<torch::Tensor> out_opt) {
   cudaSetDevice(src.get_device());

   AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
   for (int i = 0; i < indptr.dim() - 1; i++)
     AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
@@ -70,7 +71,7 @@ at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
   AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
              "Input mismatch");

-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -152,8 +153,10 @@ __global__ void gather_coo_broadcast_kernel(
   }
 }

-at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
-                           at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> out_opt) {
   cudaSetDevice(src.get_device());

   AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
   for (int i = 0; i < index.dim() - 1; i++)
@@ -162,7 +165,7 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
   src = src.contiguous();
   auto gather_dim = index.dim() - 1;

-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < index.dim(); i++)
@@ -172,7 +175,7 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = index.size(gather_dim);
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }

   auto E = index.numel();
cuda/index.cuh

 #pragma once

 #include <ATen/ATen.h>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <torch/extension.h>

 template <typename scalar1, typename scalar2, int64_t Dims>
 struct IndexToScatterOffsets3 {
cuda/indptr.cuh

 #pragma once

 #include <ATen/ATen.h>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <torch/extension.h>

 // We need our own `IndexToOffset` implementation since we do not want to
 // access the last element of the `indexptr`.
cuda/scatter.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+#define CHECK_CUDA(x)                                                         \
+  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")

-void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim);
-void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim);
-void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
-                      at::Tensor arg, int64_t dim);
-void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
-                      at::Tensor arg, int64_t dim);
-void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
-                         at::Tensor out, int64_t dim);
+void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                      torch::Tensor arg, int64_t dim);
+void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                      torch::Tensor arg, int64_t dim);
+void index_backward_cuda(torch::Tensor grad, torch::Tensor index,
+                         torch::Tensor arg, torch::Tensor out, int64_t dim);

-void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
@@ -21,7 +22,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_mul_cuda(src, index, out, dim);
 }

-void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
@@ -29,8 +30,8 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_div_cuda(src, index, out, dim);
 }

-void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   CHECK_CUDA(out);
@@ -38,8 +39,8 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_max_cuda(src, index, out, arg, dim);
 }

-void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   CHECK_CUDA(out);
@@ -47,9 +48,8 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_min_cuda(src, index, out, arg, dim);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CUDA)");
-  m.def("scatter_div", &scatter_div, "Scatter Div (CUDA)");
-  m.def("scatter_max", &scatter_max, "Scatter Max (CUDA)");
-  m.def("scatter_min", &scatter_min, "Scatter Min (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cuda::scatter_mul", &scatter_mul)
+        .op("torch_scatter_cuda::scatter_div", &scatter_div)
+        .op("torch_scatter_cuda::scatter_max", &scatter_max)
+        .op("torch_scatter_cuda::scatter_min", &scatter_min);
cuda/scatter_kernel.cu

 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <torch/extension.h>

 #include "atomics.cuh"
 #include "index.cuh"
@@ -9,8 +9,6 @@
 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS

-auto stream = at::cuda::getCurrentCUDAStream();
-
 #define KERNEL_RUN(NAME, DIMS, N, ...)                                        \
   [&] {                                                                       \
+    auto stream = at::cuda::getCurrentCUDAStream();                           \
@@ -45,8 +43,9 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
   }
 }

-void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim) {
   cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
     KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
                at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -71,8 +70,9 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
   }
 }

-void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim) {
   cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
     KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
                at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -116,8 +116,9 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
   }
 }

-void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
-                      at::Tensor arg, int64_t dim) {
+void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                      torch::Tensor arg, int64_t dim) {
   cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
     auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
     auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
@@ -148,6 +149,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
 void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                       at::Tensor arg, int64_t dim) {
   cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
     auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
     auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
cuda/segment.cpp

-#include <torch/extension.h>
+#include <torch/script.h>

-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+#define CHECK_CUDA(x)                                                         \
+  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr_cuda(at::Tensor src, at::Tensor indptr,
-                 at::optional<at::Tensor> out_opt, std::string reduce);
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
-                 std::string reduce);
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                 torch::optional<torch::Tensor> out_opt, std::string reduce);
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 std::string reduce);

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
-            std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr(torch::Tensor src, torch::Tensor indptr,
+            torch::optional<torch::Tensor> out_opt, std::string reduce) {
   CHECK_CUDA(src);
   CHECK_CUDA(indptr);
   if (out_opt.has_value())
@@ -19,8 +20,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   return segment_csr_cuda(src, indptr, out_opt, reduce);
 }

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
             std::string reduce) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
@@ -28,7 +29,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
   return segment_coo_cuda(src, index, out, reduce);
 }

-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("segment_csr", &segment_csr, "Segment CSR (CUDA)");
-  m.def("segment_coo", &segment_coo, "Segment COO (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cuda::segment_csr", &segment_csr)
+        .op("torch_scatter_cuda::segment_coo", &segment_coo);
cuda/segment_kernel.cu

 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
 #include <torch/extension.h>

 #include "atomics.cuh"
 #include "compat.cuh"
@@ -181,9 +181,11 @@ __global__ void segment_csr_broadcast_kernel(
   }
 }

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr_cuda(at::Tensor src, at::Tensor indptr,
-                 at::optional<at::Tensor> out_opt, std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                 torch::optional<torch::Tensor> out_opt, std::string reduce) {
   cudaSetDevice(src.get_device());

   AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
@@ -197,7 +199,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
   src = src.contiguous();
   auto reduce_dim = indptr.dim() - 1;

-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -208,13 +210,13 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
   } else {
     sizes = src.sizes().vec();
     sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }

-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -382,10 +384,12 @@ __global__ void segment_coo_arg_broadcast_kernel(
   }
 }

-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  std::string reduce) {
   cudaSetDevice(src.get_device());

   AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");

   // Broadcasting `index` via `expand`.
@@ -403,10 +407,10 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     if (i != reduce_dim)
       AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");

-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -467,7 +471,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
   if (reduce2REDUCE.at(reduce) == MEAN) {
     auto sizes = index.sizes().vec();
     sizes[reduce_dim] = out.size(reduce_dim);
-    auto count = at::zeros(sizes, out.options());
+    auto count = torch::zeros(sizes, out.options());
     AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
       auto count_data = count.DATA_PTR<scalar_t>();
setup.py

@@ -5,15 +5,16 @@ from setuptools import setup, find_packages
 from sys import argv

 import torch
+from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME

-# Windows users: Edit both of these to contain your VS include path, i.e.
+# Windows users: Edit both of these to contain your VS include path, i.e.:
 # cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include']
 # nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include']
 cxx_extra_compile_args = []
 nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']

-# Windows users: Edit both of these to contain your VS library path, i.e.
+# Windows users: Edit both of these to contain your VS library path, i.e.:
 # cxx_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
 # nvcc_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
 cxx_extra_link_args = []
@@ -26,7 +27,9 @@ TORCH_MINOR = int(torch.__version__.split('.')[1])
 if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
     cxx_extra_compile_args += ['-DVERSION_GE_1_3']
     nvcc_extra_compile_args += ['-DVERSION_GE_1_3']

-cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
+cmdclass = {
+    'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
+}

 ext_modules = []
 exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
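The no_python_abi_suffix=True option matters here because torch_scatter/__init__.py (below) now loads the libraries by fixed name, e.g. scatter_cpu.so rather than an ABI-tagged name like scatter_cpu.cpython-37m-x86_64-linux-gnu.so. A minimal sketch of the build pattern, with a hypothetical single-extension project standing in for the real file:

    # setup_sketch.py -- hypothetical, minimal analogue of the change above.
    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    setup(
        name='my_ops',
        ext_modules=[CppExtension('my_ops/scatter_cpu', ['cpu/scatter.cpp'])],
        # Drop the Python-ABI tag so the library keeps a stable, loadable name.
        cmdclass={
            'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
        },
    )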
torch_scatter/__init__.py

+import torch
+
+torch.ops.load_library('torch_scatter/scatter_cpu.so')
+torch.ops.load_library('torch_scatter/segment_cpu.so')
+torch.ops.load_library('torch_scatter/gather_cpu.so')
+
+try:
+    torch.ops.load_library('torch_scatter/scatter_cuda.so')
+    torch.ops.load_library('torch_scatter/segment_cuda.so')
+    torch.ops.load_library('torch_scatter/gather_cuda.so')
+except OSError as e:
+    if torch.cuda.is_available():
+        raise e
+
 from .add import scatter_add
 from .sub import scatter_sub
 from .mul import scatter_mul
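Each wrapper module below repeats the same device dispatch between the two operator namespaces. A sketch of that shared pattern, factored into a hypothetical helper for illustration:

    import torch

    def _backend(is_cuda):
        # Hypothetical helper: pick the operator namespace registered by the
        # CPU or CUDA library, mirroring the if/else blocks repeated in
        # div.py, gather.py, max.py, min.py and mul.py below.
        return (torch.ops.torch_scatter_cuda if is_cuda
                else torch.ops.torch_scatter_cpu)

    # e.g. _backend(src.is_cuda).scatter_mul(src, index, out, dim)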
torch_scatter/div.py

-from torch.autograd import Function
-
-from torch_scatter.utils.ext import get_func
+import torch
+
 from torch_scatter.utils.gen import gen

-class ScatterDiv(Function):
+class ScatterDiv(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
-        func = get_func('scatter_div', src)
-        func(src, index, out, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_div(src, index, out, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_div(src, index, out, dim)
         ctx.mark_dirty(out)
         ctx.save_for_backward(out, src, index)
torch_scatter/gather.py

 import torch

-from torch_scatter import segment_cpu, gather_cpu
-
-if torch.cuda.is_available():
-    from torch_scatter import gather_cuda, segment_cuda
-
-
-def gat(is_cuda):
-    return gather_cuda if is_cuda else gather_cpu
-
-
-def seg(is_cuda):
-    return segment_cuda if is_cuda else segment_cpu
-

 class GatherCOO(torch.autograd.Function):
     @staticmethod
@@ -22,7 +9,10 @@ class GatherCOO(torch.autograd.Function):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(index)

-        return gat(src.is_cuda).gather_coo(src, index, out)
+        if src.is_cuda:
+            return torch.ops.torch_scatter_cuda.gather_coo(src, index, out)
+        else:
+            return torch.ops.torch_scatter_cpu.gather_coo(src, index, out)

     @staticmethod
     def backward(ctx, grad_out):
@@ -30,7 +20,11 @@ class GatherCOO(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = seg(grad_out.is_cuda).segment_coo(
-                grad_out, index, grad_out.new_zeros(src_size), 'sum')
+            if grad_out.is_cuda:
+                grad_src, _ = torch.ops.torch_scatter_cuda.segment_coo(
+                    grad_out, index, grad_out.new_zeros(src_size), 'sum')
+            else:
+                grad_src, _ = torch.ops.torch_scatter_cpu.segment_coo(
+                    grad_out, index, grad_out.new_zeros(src_size), 'sum')

         return grad_src, None, None
@@ -44,7 +38,10 @@ class GatherCSR(torch.autograd.Function):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(indptr)

-        return gat(src.is_cuda).gather_csr(src, indptr, out)
+        if src.is_cuda:
+            return torch.ops.torch_scatter_cuda.gather_csr(src, indptr, out)
+        else:
+            return torch.ops.torch_scatter_cpu.gather_csr(src, indptr, out)

     @staticmethod
     def backward(ctx, grad_out):
@@ -52,7 +49,11 @@ class GatherCSR(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = seg(grad_out.is_cuda).segment_csr(
-                grad_out, indptr, grad_out.new_empty(src_size), 'sum')
+            if grad_out.is_cuda:
+                grad_src, _ = torch.ops.torch_scatter_cuda.segment_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size), 'sum')
+            else:
+                grad_src, _ = torch.ops.torch_scatter_cpu.segment_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size), 'sum')

         return grad_src, None, None
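The backward definitions above encode that gathering and segment-summing are adjoint: each gathered copy of a source row contributes one unit of gradient back to that row. A hedged numeric sketch of that identity on the CPU path (illustrative shapes):

    import torch

    torch.ops.load_library('torch_scatter/gather_cpu.so')
    torch.ops.load_library('torch_scatter/segment_cpu.so')

    src = torch.randn(2, 4)
    index = torch.tensor([0, 0, 1])  # row 0 gathered twice, row 1 once
    out = torch.ops.torch_scatter_cpu.gather_coo(src, index, src.new_empty(3, 4))

    grad_out = torch.ones(3, 4)
    grad_src, _ = torch.ops.torch_scatter_cpu.segment_coo(
        grad_out, index, grad_out.new_zeros(2, 4), 'sum')
    # grad_src[0] is all 2s, grad_src[1] all 1s -- one unit per gathered copy.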
torch_scatter/max.py

 import torch
-from torch.autograd import Function

-from torch_scatter.utils.ext import get_func
 from torch_scatter.utils.gen import gen

-class ScatterMax(Function):
+class ScatterMax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
         arg = index.new_full(out.size(), -1)
-        func = get_func('scatter_max', src)
-        func(src, index, out, arg, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_max(src, index, out, arg, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_max(src, index, out, arg, dim)
         ctx.mark_dirty(out)
         ctx.dim = dim
torch_scatter/min.py

 import torch
-from torch.autograd import Function

-from torch_scatter.utils.ext import get_func
 from torch_scatter.utils.gen import gen

-class ScatterMin(Function):
+class ScatterMin(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
         arg = index.new_full(out.size(), -1)
-        func = get_func('scatter_min', src)
-        func(src, index, out, arg, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_min(src, index, out, arg, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_min(src, index, out, arg, dim)
         ctx.mark_dirty(out)
         ctx.dim = dim
torch_scatter/mul.py

-from torch.autograd import Function
-
-from torch_scatter.utils.ext import get_func
+import torch
+
 from torch_scatter.utils.gen import gen

-class ScatterMul(Function):
+class ScatterMul(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
-        func = get_func('scatter_mul', src)
-        func(src, index, out, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_mul(src, index, out, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_mul(src, index, out, dim)
         ctx.mark_dirty(out)
         ctx.save_for_backward(out, src, index)