OpenDAS / torch-scatter · Commits
"tests/vscode:/vscode.git/clone" did not exist on "f33b89bafb0c63eea5c94b1b70f388fbdfc0bcce"
Commit 64772d75 · authored Jan 30, 2020 by rusty1s · parent d0f5005f

segment coo done
Showing 8 changed files with 394 additions and 39 deletions:
  csrc/cuda/segment_coo_cuda.cu   +358  -2
  csrc/cuda/segment_csr_cuda.cu     +4  -6
  setup.py                         +21  -19
  test/test_gather.py               +0  -2
  test/test_segment.py              +0  -1
  torch_scatter/__init__.py         +3  -9
  torch_scatter/segment_coo.py      +4  -0
  torch_scatter/segment_csr.py      +4  -0

csrc/cuda/segment_coo_cuda.cu
#include "segment_coo_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 256
#define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS
#define FULL_MASK 0xffffffff
template <typename scalar_t, ReductionType REDUCE, bool HAS_VAL>
__global__ void
segment_coo_kernel(const scalar_t *src_data,
                   const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                   scalar_t *out_data, size_t E, size_t N) {

  // Each thread processes exactly one entry. Within a warp, we perform a
  // parallel reduction across equal indices, and write the intermediate
  // result via atomics.

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = row_idx & (32 - 1);
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int64_t idx = index_info.data[offset], next_idx;
    int out_idx = (row_idx / D) * N + idx;

    scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;

#pragma unroll
    for (int i = 1; i < 32; i *= 2) {
      // Parallel reduction inside a single warp.
      tmp = __shfl_up_sync(FULL_MASK, val, i);
      next_idx = __shfl_up_sync(FULL_MASK, idx, i);
      if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
        assert(idx >= next_idx);
        if (idx == next_idx)
          Reducer<scalar_t, REDUCE>::update(&val, tmp);
      }
    }

    next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
    if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
        idx != next_idx)
      Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
  }
}
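
Aside (editor's illustration, not part of the commit): the semantics this warp-shuffle kernel implements for the non-broadcast sum case can be sketched in a few lines of PyTorch; segment_sum_coo_reference is a hypothetical helper name used only here.

import torch

def segment_sum_coo_reference(src, index, dim_size=None):
    # Reduce consecutive entries of `src` that share the same (sorted) index
    # into one output slot; 1-D case only, matching segment_coo_kernel, K == 1.
    dim_size = int(index.max()) + 1 if dim_size is None else dim_size
    out = torch.zeros(dim_size, dtype=src.dtype)
    return out.scatter_add_(0, index, src)

src = torch.tensor([1., 2., 3., 4., 5.])
index = torch.tensor([0, 0, 1, 1, 2])  # sorted, as the kernel's assert expects
print(segment_sum_coo_reference(src, index))  # tensor([3., 7., 5.])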

template <typename scalar_t>
__global__ void segment_coo_arg_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int64_t idx = index_info.data[offset];
    int out_idx = (row_idx / D) * N + idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[row_idx] == val)
      arg_out_data[out_idx] = row_idx % D;
  }
}
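
Aside (editor's sketch, not in the diff): this arg kernel runs as a second pass after the min/max reduction; any position whose value equals the reduced result records its offset within the indexed dimension. A rough Python equivalent for the 1-D case, with arg_pass being a hypothetical name:

def arg_pass(src, index, out, arg_out):
    # After `out` holds the min/max per bucket, record the position whose value
    # equals the reduced result (row % D, which is simply row in the 1-D case).
    for row, idx in enumerate(index):
        if src[row] == out[idx]:
            arg_out[idx] = row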

template <typename scalar_t, ReductionType REDUCE, int TB>
__global__ void segment_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  // Each thread processes a single column and `TB` index entries. Coalesced
  // read and write is performed in column-major order. The intermediate
  // results are written via atomics.

  int D = index_info.sizes[index_info.dims - 1];
  int E_1 = E / D;
  int E_2 = D + TB - (D % TB);

  int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
  int col_idx = blockIdx.y * blockDim.x + threadIdx.x;

  int dim_start = (row_idx * TB) / E_2;
  int row_start = (row_idx * TB) % E_2;

  if (dim_start < E_1 && col_idx < K) {

    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        dim_start * D + row_start, index_info);
    int idx1 = __ldg(index_info.data + offset), idx2;

    scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];

#pragma unroll
    for (int i = 1; i < TB; i++) {
      if (row_start + i >= D)
        break;

      idx2 = __ldg(index_info.data + offset +
                   i * index_info.strides[index_info.dims - 1]);
      assert(idx1 <= idx2);
      if (idx1 == idx2) {
        Reducer<scalar_t, REDUCE>::update(
            &val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
      } else {
        Reducer<scalar_t, REDUCE>::atomic_write(
            out_data + (dim_start * N + idx1) * K + col_idx, val);
        val = src_data[K * (dim_start * D + row_start + i) + col_idx];
      }

      idx1 = idx2;
    }

    Reducer<scalar_t, REDUCE>::atomic_write(
        out_data + (dim_start * N + idx1) * K + col_idx, val);
  }
}
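
Aside (editor's sketch, not in the diff): the per-thread loop above can be read as the following Python fragment for the sum case, where one thread owns a single column and up to TB consecutive rows of a sorted index, accumulating locally and flushing to the output whenever the index changes. thread_work is a hypothetical name and the outer slice (dim_start) is fixed to keep the sketch short.

def thread_work(src, index, out, row_start, col, TB):
    # src: 2-D list [D][K], index: sorted list of length D, out: 2-D list.
    D = len(index)
    idx1 = index[row_start]
    val = src[row_start][col]
    for i in range(1, TB):
        if row_start + i >= D:
            break
        idx2 = index[row_start + i]
        if idx1 == idx2:
            val += src[row_start + i][col]   # Reducer::update
        else:
            out[idx1][col] += val            # Reducer::atomic_write
            val = src[row_start + i][col]
        idx1 = idx2
    out[idx1][col] += val                    # flush the last run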

template <typename scalar_t>
__global__ void segment_coo_arg_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t K, size_t N) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;
  int D = index_info.sizes[index_info.dims - 1];

  if (row_idx < E && col_idx < K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int idx = __ldg(index_info.data + offset);
    int out_idx = ((row_idx / D) * N + idx) * K + col_idx;

    scalar_t val = __ldg(out_data + out_idx);
    if (src_data[thread_idx] == val)
      arg_out_data[out_idx] = row_idx % D;
  }
}

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
                 torch::optional<torch::Tensor> optional_out,
                 torch::optional<int64_t> dim_size, std::string reduce) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= index.dim());

  auto sizes = index.sizes().vec();
  for (int i = 0; i < index.dim(); i++) {
    sizes[i] = src.size(i);
  }
  index = index.expand(sizes);

  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (int i = 0; i < out.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    sizes = src.sizes().vec();
    if (dim_size.has_value())
      sizes[dim] = dim_size.value();
    else {
      auto d_size = index.max().data_ptr<int64_t>();
      auto h_size = (int64_t *)malloc(sizeof(int64_t));
      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
      sizes[dim] = 1 + *h_size;
    }
    out = torch::zeros(sizes, src.options());
  }

  torch::optional<torch::Tensor> arg_out = torch::nullopt;
  int64_t *arg_out_data = nullptr;
  if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
    arg_out = torch::full_like(out, src.size(dim), index.options());
    arg_out_data = arg_out.value().data_ptr<int64_t>();
  }

  auto E = index.numel();
  auto E_2 = index.size(dim);
  auto E_1 = index.numel() / E_2;
  auto K = src.numel() / E;
  auto N = out.size(dim);
  auto avg_len = (float)E_2 / (float)N;

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
      if (!optional_out.has_value())
        out.fill_(Reducer<scalar_t, REDUCE>::init());

      if (K == 1)
        segment_coo_kernel<scalar_t, REDUCE, true>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
                                                   out_data, E, N);
      else if (avg_len <= 8)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
            <<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else if (avg_len <= 16)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
            <<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else if (avg_len <= 32)
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
            <<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);
      else
        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
            <<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
                                         N);

      if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
        out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(),
                         (scalar_t)0);

      if (REDUCE == MIN || REDUCE == MAX) {
        if (K == 1)
          segment_coo_arg_kernel<scalar_t>
              <<<BLOCKS(1, E), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, N);
        else
          segment_coo_arg_broadcast_kernel<scalar_t>
              <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(
                  src_data, index_info, out_data, arg_out_data, E, K, N);
      }

      if (REDUCE == MEAN) {
        auto sizes = index.sizes().vec();
        sizes[dim] = out.size(dim);
        auto count = torch::zeros(sizes, out.options());
        auto count_data = count.data_ptr<scalar_t>();
        segment_coo_kernel<scalar_t, SUM, false>
            <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                   count_data, E, N);
        arg_out = count;
        for (int i = dim + 1; i < out.dim(); i++)
          count = count.unsqueeze(-1);
        out.div_(count.clamp_(1));
      }
    });
  });

  return std::make_tuple(out, arg_out);
}
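
Aside (editor's sketch, not in the diff): the launch logic above chooses a kernel from the average segment length avg_len = E_2 / N. Roughly, in Python (pick_kernel is a hypothetical name):

def pick_kernel(E_2, N, K):
    # Mirrors the dispatch in segment_coo_cuda: the warp-shuffle kernel for
    # scalar rows (K == 1), otherwise a broadcast kernel whose per-thread
    # tile size TB grows with the average segment length.
    if K == 1:
        return 'segment_coo_kernel'
    avg_len = E_2 / N
    for tb, bound in ((4, 8), (8, 16), (16, 32)):
        if avg_len <= bound:
            return f'segment_coo_broadcast_kernel<TB={tb}>'
    return 'segment_coo_broadcast_kernel<TB=32>'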

template <typename scalar_t>
__global__ void
gather_coo_kernel(const scalar_t *src_data,
                  const at::cuda::detail::TensorInfo<int64_t, int> index_info,
                  scalar_t *out_data, size_t E, size_t N) {

  int row_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (row_idx < E) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int row = index_info.data[offset];

    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N;
    scalar_t val = __ldg(src_data + offset + row);

    out_data[row_idx] = val;
  }
}

template <typename scalar_t>
__global__ void gather_coo_broadcast_kernel(
    const scalar_t *src_data,
    const at::cuda::detail::TensorInfo<int64_t, int> index_info,
    scalar_t *out_data, size_t E, size_t K, size_t N) {

  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  int row_idx = thread_idx / K;
  int col_idx = thread_idx % K;

  if (thread_idx < E * K) {
    int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
        row_idx, index_info);
    int row = index_info.data[offset];

    offset = (row_idx / index_info.sizes[index_info.dims - 1]) * N * K;
    scalar_t val = __ldg(src_data + offset + K * row + col_idx);

    out_data[thread_idx] = val;
  }
}

torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
                              torch::optional<torch::Tensor> optional_out) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  if (optional_out.has_value())
    CHECK_CUDA(optional_out.value());
  cudaSetDevice(src.get_device());

  CHECK_INPUT(src.dim() >= index.dim());

  auto sizes = index.sizes().vec();
  for (auto i = 0; i < index.dim() - 1; i++)
    sizes[i] = src.size(i);
  index = index.expand(sizes);

  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < src.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
    CHECK_INPUT(index.size(dim) == out.size(dim));
  } else {
    auto sizes = src.sizes().vec();
    sizes[dim] = index.size(dim);
    out = torch::empty(sizes, src.options());
  }

  auto E = index.numel();
  auto K = out.numel() / E;
  auto N = src.size(dim);

  auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
  auto stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo_kernel", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    if (K == 1)
      gather_coo_kernel<scalar_t><<<BLOCKS(1, E), THREADS, 0, stream>>>(
          src_data, index_info, out_data, E, N);
    else
      gather_coo_broadcast_kernel<scalar_t>
          <<<BLOCKS(1, E * K), THREADS, 0, stream>>>(src_data, index_info,
                                                     out_data, E, K, N);
  });

  return out;
}
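
Aside (editor's illustration, not part of the commit): for a 1-D index, gather_coo is the transpose of segment_coo, i.e. every output entry copies the source entry its index points at. A PyTorch sketch of the semantics (gather_coo_reference is a hypothetical name):

import torch

def gather_coo_reference(src, index):
    # out[i] = src[index[i]] along the indexed dimension (1-D case).
    return src.index_select(0, index)

src = torch.tensor([10., 20., 30.])
index = torch.tensor([0, 0, 1, 1, 2])
print(gather_coo_reference(src, index))  # tensor([10., 10., 20., 20., 30.])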

csrc/cuda/segment_csr_cuda.cu
@@ -237,13 +237,11 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
-    auto d_gather_size = indptr.flatten()[-1].data_ptr<int64_t>();
-    auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
-    cudaMemcpy(h_gather_size, d_gather_size, sizeof(int64_t),
-               cudaMemcpyDeviceToHost);
+    auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
+    auto h_size = (int64_t *)malloc(sizeof(int64_t));
+    cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
     auto sizes = src.sizes().vec();
-    sizes[dim] = *h_gather_size;
+    sizes[dim] = *h_size;
     out = torch::empty(sizes, src.options());
   }
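
Aside (editor's illustration, not in the diff): both the old and the renamed variables hold the last element of indptr, copied device-to-host, which in CSR layout is the total gather length used to size the output. In Python terms:

import torch

indptr = torch.tensor([0, 2, 5, 6])      # CSR row pointers
gather_size = int(indptr.flatten()[-1])  # 6: number of rows gather_csr emits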

setup.py
@@ -8,31 +8,22 @@ import torch
 from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
 
+WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
+WITH_CUDA = WITH_CUDA or os.getenv('FORCE_CUDA', '0') == '1'
+
 
 def get_extensions():
-    this_dir = osp.dirname(osp.abspath(__file__))
-    extensions_dir = osp.join(this_dir, 'csrc')
-
-    main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
-    cpu_files = glob.glob(osp.join(extensions_dir, 'cpu', '*.cpp'))
-    cuda_files = glob.glob(osp.join(extensions_dir, 'cuda', '*.cu'))
-
     Extension = CppExtension
-    sources = main_files + cpu_files
-
     define_macros = []
     extra_compile_args = {'cxx': [], 'nvcc': []}
 
     # Windows users: Edit both of these to contain your VS include path, i.e.:
     # extra_compile_args['cxx'] += ['-I{VISUAL_STUDIO_DIR}\\include']
     # extra_compile_args['nvcc'] += ['-I{VISUAL_STUDIO_DIR}\\include']
 
-    if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv('FORCE_CUDA', '0') == '1':
+    if WITH_CUDA:
         Extension = CUDAExtension
-        sources += cuda_files
         define_macros += [('WITH_CUDA', None)]
         nvcc_flags = os.getenv('NVCC_FLAGS', '')
         nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
         nvcc_flags += ['-arch=sm_35', '--expt-relaxed-constexpr']

@@ -42,15 +33,26 @@ def get_extensions():
     if sys.platform == 'win32':
         extra_compile_args['cxx'] += ['/MP']
 
-    return [
-        Extension(
-            'torch_scatter._C',
+    extensions_dir = osp.join(osp.dirname(osp.abspath(__file__)), 'csrc')
+    main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
+
+    extensions = []
+    for main in main_files:
+        name = main.split(os.sep)[-1][:-4]
+
+        sources = [main, osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')]
+        if WITH_CUDA:
+            sources += [osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')]
+
+        extension = Extension(
+            f'torch_scatter._{name}',
             sources,
             include_dirs=[extensions_dir],
             define_macros=define_macros,
             extra_compile_args=extra_compile_args,
         )
-    ]
+        extensions += [extension]
+
+    return extensions
 
 
 install_requires = []

@@ -59,7 +61,7 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_scatter',
-    version='1.5.0',
+    version='2.0.0',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url='https://github.com/rusty1s/pytorch_scatter',
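
Aside (editor's illustration, not in the diff): the new get_extensions() derives one extension module per top-level csrc/*.cpp file. Under the f'torch_scatter._{name}' pattern, a hypothetical csrc/segment_coo.cpp maps as follows (assuming a POSIX path separator):

import os

main = 'csrc/segment_coo.cpp'       # hypothetical entry from main_files
name = main.split(os.sep)[-1][:-4]  # strip directory and '.cpp' -> 'segment_coo'
print(f'torch_scatter._{name}')     # extension name: torch_scatter._segment_coo
# Its sources: csrc/segment_coo.cpp, csrc/cpu/segment_coo_cpu.cpp and,
# when WITH_CUDA is set, csrc/cuda/segment_coo_cuda.cu.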

test/test_gather.py
@@ -7,8 +7,6 @@ from torch_scatter import gather_csr, gather_coo
 from .utils import tensor, dtypes, devices
 
-devices = ['cpu']
-
 tests = [{
     'src': [1, 2, 3, 4],

test/test_segment.py
@@ -8,7 +8,6 @@ import torch_scatter
 from .utils import tensor, dtypes, devices
 
 reductions = ['sum', 'mean', 'min', 'max']
-devices = ['cpu']
 
 tests = [{

torch_scatter/__init__.py
-import os.path as osp
-
 import torch
 
-torch.ops.load_library(osp.join(osp.dirname(osp.abspath(__file__)), '_C.so'))
-
 from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
                           segment_min_csr, segment_max_csr, segment_csr,
-                          gather_csr)  # noqa
+                          gather_csr)
 from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
                           segment_min_coo, segment_max_coo, segment_coo,
-                          gather_coo)  # noqa
+                          gather_coo)
 
-__version__ = '1.5.0'
+__version__ = '2.0.0'
 
 __all__ = [
     'segment_sum_csr',

torch_scatter/segment_coo.py
+import os.path as osp
+
 from typing import Optional, Tuple
 
 import torch
 
+torch.ops.load_library(osp.join(osp.dirname(osp.abspath(__file__)), '_segment_coo.so'))
+
 
 @torch.jit.script
 def segment_sum_coo(src: torch.Tensor, index: torch.Tensor,
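
Aside (editor's usage sketch, not part of the diff, assuming the 2.x Python signature segment_coo(src, index, out=None, dim_size=None, reduce='sum')): once the _segment_coo extension is loaded, the op is called with a sorted index that broadcasts against the leading dimensions of src.

import torch
from torch_scatter import segment_coo

src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 3])  # sorted along the index dimension
index = index.view(1, -1)                 # broadcasting over the first dim

out = segment_coo(src, index, reduce='sum')
print(out.size())  # torch.Size([10, 4, 64])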

torch_scatter/segment_csr.py
+import os.path as osp
+
 from typing import Optional, Tuple
 
 import torch
 
+torch.ops.load_library(osp.join(osp.dirname(osp.abspath(__file__)), '_segment_csr.so'))
+
 
 @torch.jit.script
 def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,