Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
15c56351
Unverified
Commit
15c56351
authored
Apr 22, 2022
by
Matthias Fey
Committed by
GitHub
Apr 22, 2022
Browse files
Version up (#224)
* version up * formatting * fix * reset * revert
parent
79535f30
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
47 additions
and
58 deletions
+47
-58
CMakeLists.txt
CMakeLists.txt
+1
-1
conda/pytorch-sparse/meta.yaml
conda/pytorch-sparse/meta.yaml
+1
-1
csrc/cpu/neighbor_sample_cpu.cpp
csrc/cpu/neighbor_sample_cpu.cpp
+38
-54
setup.cfg
setup.cfg
+5
-0
setup.py
setup.py
+1
-1
torch_sparse/__init__.py
torch_sparse/__init__.py
+1
-1
No files found.
CMakeLists.txt
View file @
15c56351
cmake_minimum_required
(
VERSION 3.0
)
project
(
torchsparse
)
set
(
CMAKE_CXX_STANDARD 14
)
set
(
TORCHSPARSE_VERSION 0.
6.13
)
set
(
TORCHSPARSE_VERSION 0.
7.0
)
option
(
WITH_CUDA
"Enable CUDA support"
OFF
)
option
(
WITH_PYTHON
"Link to Python when building"
ON
)
...
...
conda/pytorch-sparse/meta.yaml
View file @
15c56351
package
:
name
:
pytorch-sparse
version
:
0.
6.13
version
:
0.
7.0
source
:
path
:
../..
...
...
csrc/cpu/neighbor_sample_cpu.cpp
View file @
15c56351
...
...
@@ -114,35 +114,31 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
from_vector
<
int64_t
>
(
cols
),
from_vector
<
int64_t
>
(
edges
));
}
bool
satisfy_time_constraint
(
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
node_time_dict
,
const
std
::
string
&
src_node_type
,
const
int64_t
&
dst_time
,
const
int64_t
&
s
ampled
_node
)
{
bool
satisfy_time_constraint
(
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
node_time_dict
,
const
node_t
&
src_node_type
,
const
int64_t
&
dst_time
,
const
int64_t
&
s
rc
_node
)
{
// whether src -> dst obeys the time constraint
try
{
const
auto
*
src_time
=
node_time_dict
.
at
(
src_node_type
).
data_ptr
<
int64_t
>
();
return
dst_time
<
src_time
[
sampled_node
];
}
catch
(
int
err
)
{
auto
src_time
=
node_time_dict
.
at
(
src_node_type
).
data_ptr
<
int64_t
>
();
return
dst_time
<
src_time
[
src_node
];
}
catch
(
int
err
)
{
// if the node type does not have timestamp, fall back to normal sampling
return
true
;
}
}
template
<
bool
replace
,
bool
directed
,
bool
temporal
>
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_sample
(
const
vector
<
node_t
>
&
node_types
,
const
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
node_time_dict
)
{
//bool temporal = (!node_time_dict.empty());
const
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
node_time_dict
)
{
// Create a mapping to convert single string relations to edge type triplets:
unordered_map
<
rel_t
,
edge_t
>
to_edge_type
;
for
(
const
auto
&
k
:
edge_types
)
...
...
@@ -174,11 +170,12 @@ hetero_sample(const vector<node_t> &node_types,
const
torch
::
Tensor
&
input_node
=
kv
.
value
();
const
auto
*
input_node_data
=
input_node
.
data_ptr
<
int64_t
>
();
// dummy value. will be reset to root time if is_temporal==true
auto
*
node_time_data
=
input_node
.
data_ptr
<
int64_t
>
()
;
int64_t
*
node_time_data
;
// root_time[i] stores the timestamp of the computation tree root
// of the node samples[i]
if
(
temporal
)
{
node_time_data
=
node_time_dict
.
at
(
node_type
).
data_ptr
<
int64_t
>
();
torch
::
Tensor
node_time
=
node_time_dict
.
at
(
node_type
);
node_time_data
=
node_time
.
data_ptr
<
int64_t
>
();
}
auto
&
samples
=
samples_dict
.
at
(
node_type
);
...
...
@@ -220,7 +217,7 @@ hetero_sample(const vector<node_t> &node_types,
const
auto
&
begin
=
slice_dict
.
at
(
dst_node_type
).
first
;
const
auto
&
end
=
slice_dict
.
at
(
dst_node_type
).
second
;
if
(
begin
==
end
){
if
(
begin
==
end
)
{
continue
;
}
// for temporal sampling, sampled src node cannot have timestamp greater
...
...
@@ -370,22 +367,17 @@ hetero_sample(const vector<node_t> &node_types,
template
<
bool
replace
,
bool
directed
>
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_sample_random
(
const
vector
<
node_t
>
&
node_types
,
const
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
)
{
hetero_sample_random
(
const
vector
<
node_t
>
&
node_types
,
const
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
)
{
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
empty_dict
;
return
hetero_sample
<
replace
,
directed
,
false
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
,
empty_dict
);
return
hetero_sample
<
replace
,
directed
,
false
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
,
empty_dict
);
}
}
// namespace
...
...
@@ -418,24 +410,20 @@ hetero_neighbor_sample_cpu(
const
int64_t
num_hops
,
const
bool
replace
,
const
bool
directed
)
{
if
(
replace
&&
directed
)
{
return
hetero_sample_random
<
true
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
return
hetero_sample_random
<
true
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
else
if
(
replace
&&
!
directed
)
{
return
hetero_sample_random
<
true
,
false
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
else
if
(
!
replace
&&
directed
)
{
return
hetero_sample_random
<
false
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
else
{
return
hetero_sample_random
<
false
,
false
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
}
...
...
@@ -453,23 +441,19 @@ hetero_neighbor_temporal_sample_cpu(
if
(
replace
&&
directed
)
{
return
hetero_sample
<
true
,
true
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
,
node_time_dict
);
}
else
if
(
replace
&&
!
directed
)
{
return
hetero_sample
<
true
,
false
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
,
node_time_dict
);
}
else
if
(
!
replace
&&
directed
)
{
return
hetero_sample
<
false
,
true
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
,
node_time_dict
);
}
else
{
return
hetero_sample
<
false
,
false
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
,
node_time_dict
);
}
}
setup.cfg
View file @
15c56351
...
...
@@ -17,3 +17,8 @@ test = pytest
[tool:pytest]
addopts = --capture=no
[isort]
multi_line_output=3
include_trailing_comma = True
skip=.gitignore,__init__.py
setup.py
View file @
15c56351
...
...
@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
from
torch.utils.cpp_extension
import
(
CUDA_HOME
,
BuildExtension
,
CppExtension
,
CUDAExtension
)
__version__
=
'0.
6.13
'
__version__
=
'0.
7.0
'
URL
=
'https://github.com/rusty1s/pytorch_sparse'
WITH_CUDA
=
torch
.
cuda
.
is_available
()
and
CUDA_HOME
is
not
None
...
...
torch_sparse/__init__.py
View file @
15c56351
...
...
@@ -3,7 +3,7 @@ import os.path as osp
import
torch
__version__
=
'0.
6.13
'
__version__
=
'0.
7.0
'
for
library
in
[
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
,
'_metis'
,
'_rw'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment