Unverified commit 15c56351, authored by Matthias Fey, committed by GitHub

Version up (#224)

* version up

* formatting

* fix

* reset

* revert
parent 79535f30
 cmake_minimum_required(VERSION 3.0)
 project(torchsparse)
 set(CMAKE_CXX_STANDARD 14)
-set(TORCHSPARSE_VERSION 0.6.13)
+set(TORCHSPARSE_VERSION 0.7.0)
 option(WITH_CUDA "Enable CUDA support" OFF)
 option(WITH_PYTHON "Link to Python when building" ON)
...
 package:
   name: pytorch-sparse
-  version: 0.6.13
+  version: 0.7.0
 source:
   path: ../..
...
@@ -114,35 +114,31 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
                     from_vector<int64_t>(cols), from_vector<int64_t>(edges));
 }
 
-bool satisfy_time_constraint(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
-                             const std::string &src_node_type,
-                             const int64_t &dst_time,
-                             const int64_t &sampled_node) {
+bool satisfy_time_constraint(
+    const c10::Dict<node_t, torch::Tensor> &node_time_dict,
+    const node_t &src_node_type, const int64_t &dst_time,
+    const int64_t &src_node) {
   // whether src -> dst obeys the time constraint
   try {
-    const auto *src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
-    return dst_time < src_time[sampled_node];
-  }
-  catch (int err) {
+    auto src_time = node_time_dict.at(src_node_type).data_ptr<int64_t>();
+    return dst_time < src_time[src_node];
+  } catch (int err) {
     // if the node type does not have timestamp, fall back to normal sampling
     return true;
   }
 }
 
 template <bool replace, bool directed, bool temporal>
 tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
       c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
 hetero_sample(const vector<node_t> &node_types,
               const vector<edge_t> &edge_types,
               const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
               const c10::Dict<rel_t, torch::Tensor> &row_dict,
               const c10::Dict<node_t, torch::Tensor> &input_node_dict,
               const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
               const int64_t num_hops,
               const c10::Dict<node_t, torch::Tensor> &node_time_dict) {
-  //bool temporal = (!node_time_dict.empty());
   // Create a mapping to convert single string relations to edge type triplets:
   unordered_map<rel_t, edge_t> to_edge_type;
   for (const auto &k : edge_types)
@@ -174,11 +170,12 @@ hetero_sample(const vector<node_t> &node_types,
     const torch::Tensor &input_node = kv.value();
     const auto *input_node_data = input_node.data_ptr<int64_t>();
     // dummy value. will be reset to root time if is_temporal==true
-    auto *node_time_data = input_node.data_ptr<int64_t>();
+    int64_t *node_time_data;
     // root_time[i] stores the timestamp of the computation tree root
     // of the node samples[i]
     if (temporal) {
-      node_time_data = node_time_dict.at(node_type).data_ptr<int64_t>();
+      torch::Tensor node_time = node_time_dict.at(node_type);
+      node_time_data = node_time.data_ptr<int64_t>();
     }
     auto &samples = samples_dict.at(node_type);
@@ -220,7 +217,7 @@ hetero_sample(const vector<node_t> &node_types,
       const auto &begin = slice_dict.at(dst_node_type).first;
       const auto &end = slice_dict.at(dst_node_type).second;
-      if (begin == end){
+      if (begin == end) {
         continue;
       }
       // for temporal sampling, sampled src node cannot have timestamp greater
@@ -370,22 +367,17 @@ hetero_sample(const vector<node_t> &node_types,
 template <bool replace, bool directed>
 tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
       c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
-hetero_sample_random(const vector<node_t> &node_types,
-                     const vector<edge_t> &edge_types,
-                     const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
-                     const c10::Dict<rel_t, torch::Tensor> &row_dict,
-                     const c10::Dict<node_t, torch::Tensor> &input_node_dict,
-                     const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
-                     const int64_t num_hops) {
+hetero_sample_random(
+    const vector<node_t> &node_types, const vector<edge_t> &edge_types,
+    const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
+    const c10::Dict<rel_t, torch::Tensor> &row_dict,
+    const c10::Dict<node_t, torch::Tensor> &input_node_dict,
+    const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
+    const int64_t num_hops) {
   c10::Dict<node_t, torch::Tensor> empty_dict;
-  return hetero_sample<replace, directed, false>(node_types,
-                                                 edge_types,
-                                                 colptr_dict,
-                                                 row_dict,
-                                                 input_node_dict,
-                                                 num_neighbors_dict,
-                                                 num_hops,
-                                                 empty_dict);
+  return hetero_sample<replace, directed, false>(
+      node_types, edge_types, colptr_dict, row_dict, input_node_dict,
+      num_neighbors_dict, num_hops, empty_dict);
 }
 
 } // namespace
@@ -418,24 +410,20 @@ hetero_neighbor_sample_cpu(
     const int64_t num_hops, const bool replace, const bool directed) {
 
   if (replace && directed) {
-    return hetero_sample_random<true, true>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
-        num_neighbors_dict, num_hops);
+    return hetero_sample_random<true, true>(node_types, edge_types, colptr_dict,
+                                            row_dict, input_node_dict,
+                                            num_neighbors_dict, num_hops);
   } else if (replace && !directed) {
     return hetero_sample_random<true, false>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
+        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
         num_neighbors_dict, num_hops);
   } else if (!replace && directed) {
     return hetero_sample_random<false, true>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
+        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
         num_neighbors_dict, num_hops);
   } else {
     return hetero_sample_random<false, false>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
+        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
         num_neighbors_dict, num_hops);
   }
 }
@@ -453,23 +441,19 @@ hetero_neighbor_temporal_sample_cpu(
   if (replace && directed) {
     return hetero_sample<true, true, true>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
+        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
         num_neighbors_dict, num_hops, node_time_dict);
   } else if (replace && !directed) {
     return hetero_sample<true, false, true>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
+        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
         num_neighbors_dict, num_hops, node_time_dict);
   } else if (!replace && directed) {
     return hetero_sample<false, true, true>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
+        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
         num_neighbors_dict, num_hops, node_time_dict);
   } else {
     return hetero_sample<false, false, true>(
-        node_types, edge_types, colptr_dict,
-        row_dict, input_node_dict,
+        node_types, edge_types, colptr_dict, row_dict, input_node_dict,
         num_neighbors_dict, num_hops, node_time_dict);
   }
 }
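The two hunks above only tidy up call-site formatting; the dispatch itself is unchanged: the runtime flags `replace` and `directed` select one of four pre-built `hetero_sample<replace, directed, temporal>` template instantiations. Below is a minimal, self-contained sketch of that boolean-to-template dispatch pattern; the names and the dummy return value are illustrative only and not part of torch_sparse:

// Illustrative sketch only: runtime booleans selecting compile-time template
// instantiations, mirroring hetero_neighbor_sample_cpu above.
#include <cstdio>

template <bool replace, bool directed> int sample_stub(int num_hops) {
  // A real sampler would branch on `replace`/`directed` at compile time;
  // here we just encode the flags into the result for demonstration.
  return (replace ? 1 : 0) + (directed ? 2 : 0) + 10 * num_hops;
}

int dispatch(bool replace, bool directed, int num_hops) {
  // Four runtime branches map to the four template instantiations.
  if (replace && directed)
    return sample_stub<true, true>(num_hops);
  else if (replace && !directed)
    return sample_stub<true, false>(num_hops);
  else if (!replace && directed)
    return sample_stub<false, true>(num_hops);
  else
    return sample_stub<false, false>(num_hops);
}

int main() {
  std::printf("%d\n", dispatch(/*replace=*/false, /*directed=*/true, 2));
  return 0;
}

Resolving the flags once at the entry point lets the compiler specialize the inner sampling loops for each flag combination instead of re-checking the booleans per sampled neighbor.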
@@ -17,3 +17,8 @@ test = pytest
 
 [tool:pytest]
 addopts = --capture=no
+
+[isort]
+multi_line_output=3
+include_trailing_comma = True
+skip=.gitignore,__init__.py
@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
 from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
                                        CUDAExtension)
 
-__version__ = '0.6.13'
+__version__ = '0.7.0'
 URL = 'https://github.com/rusty1s/pytorch_sparse'
 
 WITH_CUDA = torch.cuda.is_available() and CUDA_HOME is not None
...
@@ -3,7 +3,7 @@ import os.path as osp
 
 import torch
 
-__version__ = '0.6.13'
+__version__ = '0.7.0'
 
 for library in [
         '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
...