Commit 2951b12d authored by aiss's avatar aiss
Browse files

push v0.6.18 version

parent e8309f27
# Top-level build script for the torchsparse C++/CUDA extension library.
cmake_minimum_required(VERSION 3.10)
project(torchsparse)
set(CMAKE_CXX_STANDARD 14)
set(TORCHSPARSE_VERSION 0.6.18)
# Make the bundled cmake/FindMETIS.cmake module discoverable by find_package().
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
# Build-time feature toggles.
option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
option(WITH_METIS "Enable METIS support" OFF)
if(WITH_CUDA)
enable_language(CUDA)
# NOTE(review): presumably disables CUDA's half-precision operators to avoid
# clashes with ATen's half type (a common PyTorch-extension workaround) --
# confirm against the upstream build.
add_definitions(-D__CUDA_NO_HALF_OPERATORS__)
add_definitions(-DWITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
endif()
if (WITH_PYTHON)
add_definitions(-DWITH_PYTHON)
find_package(Python3 COMPONENTS Development)
endif()
# Torch is mandatory; METIS is optional and resolved via cmake/FindMETIS.cmake.
find_package(Torch REQUIRED)
if (WITH_METIS)
add_definitions(-DWITH_METIS)
find_package(METIS)
endif()
# Gather sources by globbing.  NOTE(review): file(GLOB) is evaluated only at
# configure time, so newly added source files are not picked up until CMake is
# re-run; an explicit source list would be more robust.
file(GLOB HEADERS csrc/*.h)
file(GLOB OPERATOR_SOURCES csrc/*.* csrc/cpu/*.*)
if(WITH_CUDA)
file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} csrc/cuda/*.h csrc/cuda/*.cu)
endif()
# The extension library itself, linked against Torch (and optionally
# Python / METIS).
add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})
if (WITH_PYTHON)
target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python)
endif()
if (WITH_METIS)
target_include_directories(${PROJECT_NAME} PRIVATE ${METIS_INCLUDE_DIRS})
target_link_libraries(${PROJECT_NAME} PRIVATE ${METIS_LIBRARIES})
endif()
# Enable OpenMP for the CPU kernels.  Use the imported target instead of
# mutating the global CMAKE_C_FLAGS/CMAKE_CXX_FLAGS/linker-flag strings: the
# OpenMP::OpenMP_CXX target (provided by FindOpenMP since CMake 3.9) carries
# both the compile and the link options and keeps them scoped to this library.
find_package(OpenMP)
if(OpenMP_CXX_FOUND)
  target_link_libraries(${PROJECT_NAME} PRIVATE OpenMP::OpenMP_CXX)
endif()
set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchSparse)
# GNUInstallDirs must be included *before* CMAKE_INSTALL_INCLUDEDIR is
# referenced below (previously it was included after, so the variable
# expanded empty in the INSTALL_INTERFACE).
include(GNUInstallDirs)
include(CMakePackageConfigHelpers)
# Usage requirements for consumers.  NOTE(fix): the previous code placed the
# header *file list* (${HEADERS}) inside BUILD_INTERFACE, but
# target_include_directories() expects directories, not files.  Expose the
# csrc/ directory for build-tree consumers; installed consumers resolve
# headers under CMAKE_INSTALL_INCLUDEDIR.
target_include_directories(${PROJECT_NAME} INTERFACE
  "$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/csrc>"
  $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>)
# Vendored header-only parallel-hashmap (phmap), used internally only.
set(PHMAP_DIR third_party/parallel-hashmap)
target_include_directories(${PROJECT_NAME} PRIVATE ${PHMAP_DIR})
# Generate and install the CMake package files (TorchSparseConfig.cmake plus
# a version file) so downstream projects can `find_package(TorchSparse)`.
set(TORCHSPARSE_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchSparse" CACHE STRING "install path for TorchSparseConfig.cmake")
configure_package_config_file(cmake/TorchSparseConfig.cmake.in
"${CMAKE_CURRENT_BINARY_DIR}/TorchSparseConfig.cmake"
INSTALL_DESTINATION ${TORCHSPARSE_CMAKECONFIG_INSTALL_DIR})
write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/TorchSparseConfigVersion.cmake
VERSION ${TORCHSPARSE_VERSION}
COMPATIBILITY AnyNewerVersion)
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/TorchSparseConfig.cmake
${CMAKE_CURRENT_BINARY_DIR}/TorchSparseConfigVersion.cmake
DESTINATION ${TORCHSPARSE_CMAKECONFIG_INSTALL_DIR})
# Install the shared library and export its target for the package config.
install(TARGETS ${PROJECT_NAME}
EXPORT TorchSparseTargets
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
install(EXPORT TorchSparseTargets
NAMESPACE TorchSparse::
DESTINATION ${TORCHSPARSE_CMAKECONFIG_INSTALL_DIR})
# Top-level public headers (globbed above).
install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME})
# CPU backend headers.
install(FILES
csrc/cpu/convert_cpu.h
csrc/cpu/diag_cpu.h
csrc/cpu/metis_cpu.h
csrc/cpu/rw_cpu.h
csrc/cpu/saint_cpu.h
csrc/cpu/sample_cpu.h
csrc/cpu/spmm_cpu.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu)
if(WITH_CUDA)
# CUDA backend headers.
install(FILES
csrc/cuda/convert_cuda.h
csrc/cuda/diag_cuda.h
csrc/cuda/rw_cuda.h
csrc/cuda/spmm_cuda.h
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda)
endif()
if(WITH_CUDA)
# NOTE(review): this clears INTERFACE_COMPILE_OPTIONS on imported targets
# owned by the Torch package -- presumably to drop interface flags that break
# this build.  Mutating another package's targets is fragile; confirm this is
# still required with the targeted Torch version.
set_property(TARGET torch_cuda PROPERTY INTERFACE_COMPILE_OPTIONS "")
set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "")
endif()
# Python sdist manifest: ship the C++ sources and the vendored
# parallel-hashmap, drop its docs/tests and the project's own tests.
# (Fix: the duplicate `recursive-exclude test *` line was removed.)
include README.md
include LICENSE
recursive-include csrc *
recursive-include third_party *
recursive-exclude third_party/parallel-hashmap/css *
recursive-exclude third_party/parallel-hashmap/html *
recursive-exclude third_party/parallel-hashmap/tests *
recursive-exclude third_party/parallel-hashmap/examples *
recursive-exclude third_party/parallel-hashmap/benchmark *
recursive-exclude test *
recursive-exclude benchmark *
......@@ -3,8 +3,8 @@
torch-sparse是PyTorch的一个扩展库,用于处理稀疏数据。它提供了一些功能强大的稀疏矩阵操作,以及能够高效执行稀疏计算的函数和工具。
## 依赖安装
+ pytorch1.10或者pytorch1.13 以及对应的torchvision(建议dtk-22.04.2、dtk-23.04与dtk-23.10)
+ python 3.7-3.10
+ 建议pytorch1.13以及对应的torchvision(建议dtk-23.10)
+ python 3.8及以上
### 1、使用源码编译方式安装
......@@ -27,20 +27,15 @@ git clone http://developer.hpccube.com/codes/aicomponent/torch-sparce # 根据
```
- 源码编译(进入torch-sparse目录):
```
export C_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/gflags-2.1.2-build/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/gflags-2.1.2-build/include:$CPLUS_INCLUDE_PATH
export C_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/glog-build/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=/public/software/apps/DeepLearning/PyTorch_Lib/glog-build/include:$CPLUS_INCLUDE_PATH
export C_INCLUDE_PATH=$ROCM_PATH/rocrand/include:$C_INCLUDE_PATH
export CPLUS_INCLUDE_PATH=$ROCM_PATH/rocrand/include:$CPLUS_INCLUDE_PATH
export LD_LIBRARY_PATH=$ROCM_PATH/rocrand/lib:$LD_LIBRARY_PATH
export FORCE_ONLY_HIP=1
source /opt/dtk/env.sh
export FORCE_CUDA=1
export CC=hipcc
export CXX=hipcc
python setup.py install
python setup.py install bdist_wheel
```
#### 注意事项
+ 编译后生成的安装包whl在dist文件夹下,pip直接安装即可
+ 若使用pip install下载安装过慢,可添加pypi清华源:-i https://pypi.tuna.tsinghua.edu.cn/simple/
+ ROCM_PATH为dtk的路径,默认为/opt/dtk
......@@ -74,4 +69,4 @@ tensor([[6.0, 8.0],
## 参考资料
- [README_ORIGIN](README_ORIGIN.md)
- [https://pypi.org/project/torch-sparse/0.6.13/](https://pypi.org/project/torch-sparse/0.6.13/)
- [https://github.com/rusty1s/pytorch_sparse/tree/0.6.18](https://github.com/rusty1s/pytorch_sparse/tree/0.6.18)
......@@ -43,40 +43,40 @@ conda install pytorch-sparse -c pyg
We alternatively provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).
#### PyTorch 1.11
#### PyTorch 2.1
To install the binaries for PyTorch 1.11.0, simply run
To install the binaries for PyTorch 2.1.0, simply run
```
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.11.0+${CUDA}.html
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
```
where `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, or `cu115` depending on your PyTorch installation.
where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` depending on your PyTorch installation.
| | `cpu` | `cu102` | `cu113` | `cu115` |
|-------------|-------|---------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ | ✅ |
| **Windows** | ✅ | | ✅ | ✅ |
| **macOS** | ✅ | | | |
| | `cpu` | `cu118` | `cu121` |
|-------------|-------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | |
#### PyTorch 1.10
#### PyTorch 2.0
To install the binaries for PyTorch 1.10.0, PyTorch 1.10.1 and PyTorch 1.10.2, simply run
To install the binaries for PyTorch 2.0.0, simply run
```
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+${CUDA}.html
pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
```
where `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu111`, or `cu113` depending on your PyTorch installation.
where `${CUDA}` should be replaced by either `cpu`, `cu117`, or `cu118` depending on your PyTorch installation.
| | `cpu` | `cu102` | `cu111` | `cu113` |
|-------------|-------|---------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | | |
| | `cpu` | `cu117` | `cu118` |
|-------------|-------|---------|---------|
| **Linux** | ✅ | ✅ | ✅ |
| **Windows** | ✅ | ✅ | ✅ |
| **macOS** | ✅ | | |
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1 and PyTorch 1.9.0 (following the same procedure).
For older versions, you might need to explicitly specify the latest supported version number in order to prevent a manual installation from source.
**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1 and PyTorch 1.13.0/1.13.1 (following the same procedure).
For older versions, you need to explicitly specify the latest supported version number or install via `pip install --no-index` in order to prevent a manual installation from source.
You can look up the latest supported version number [here](https://data.pyg.org/whl).
### From source
......@@ -94,7 +94,7 @@ $ echo $CPATH
>>> /usr/local/cuda/include:...
```
If you want to additionally build `torch-sparse` with METIS support, *e.g.* for partitioning, please download and install the [METIS library](http://glaros.dtc.umn.edu/gkhome/metis/metis/download) by following the instructions in the `Install.txt` file.
If you want to additionally build `torch-sparse` with METIS support, *e.g.* for partitioning, please download and install the [METIS library](https://web.archive.org/web/20211119110155/http://glaros.dtc.umn.edu/gkhome/metis/metis/download) by following the instructions in the `Install.txt` file.
Note that METIS needs to be installed with 64 bit `IDXTYPEWIDTH` by changing `include/metis.h`.
Afterwards, set the environment variable `WITH_METIS=1`.
......@@ -294,21 +294,22 @@ print(valueC)
tensor([8.0, 6.0, 8.0])
```
## Running tests
```
pytest
```
## C++ API
`torch-sparse` also offers a C++ API that contains C++ equivalent of python models.
For this, we need to add `TorchLib` to the `-DCMAKE_PREFIX_PATH` (*e.g.*, it may exist in `{CONDA}/lib/python{X.X}/site-packages/torch` if installed via `conda`):
```
mkdir build
cd build
# Add -DWITH_CUDA=on support for the CUDA if needed
cmake ..
# Add -DWITH_CUDA=on for CUDA support
cmake -DCMAKE_PREFIX_PATH="..." ..
make
make install
```
## Running tests
```
pytest
```
import argparse
import itertools
import os.path as osp
import time
import torch
import wget
from scipy.io import loadmat
from torch_scatter import scatter_add
from torch_sparse.tensor import SparseTensor
short_rows = [
('DIMACS10', 'citationCiteseer'),
('SNAP', 'web-Stanford'),
]
long_rows = [
('Janna', 'StocF-1465'),
('GHS_psdef', 'ldoor'),
]
def download(dataset):
    """Fetch every benchmark matrix from the SuiteSparse collection into the
    current directory (note: the ``dataset`` argument is ignored; all matrices
    in ``long_rows``/``short_rows`` are fetched up front)."""
    base = 'https://sparse.tamu.edu/mat/{}/{}.mat'
    for group, name in itertools.chain(long_rows, short_rows):
        if osp.exists(f'{name}.mat'):
            continue
        print(f'Downloading {group}/{name}:')
        wget.download(base.format(group, name))
        print('')
def bold(text, flag=True):
    """Wrap ``text`` in ANSI bold escape codes when ``flag`` is truthy,
    otherwise return it unchanged."""
    if not flag:
        return text
    return '\033[1m' + format(text) + '\033[0m'
@torch.no_grad()
def correctness(dataset):
    """Verify that SparseTensor matmul agrees with torch.sparse matmul on one
    SuiteSparse matrix for all feature sizes in the global ``sizes`` list.

    Out-of-memory RuntimeErrors are tolerated (the cache is cleared and the
    size is skipped); any other RuntimeError propagates.
    """
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    row = torch.from_numpy(mat_scipy.tocoo().row).to(args.device, torch.long)
    col = torch.from_numpy(mat_scipy.tocoo().col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)
            out1 = mat @ x
            out2 = mat_pytorch @ x
            assert torch.allclose(out1, out2, atol=1e-4)
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Fix: re-raise instead of `raise RuntimeError(e)`, which
                # wrapped the exception and discarded the original traceback.
                raise
            torch.cuda.empty_cache()
def time_func(func, x):
    """Return the wall-clock seconds for ``iters`` calls of ``func(x)``.

    Honors the global ``args.with_backward`` flag: when set, also runs a
    backward pass through the output.  Synchronizes CUDA/MPS around the timed
    region so pending device work is not attributed to the wrong measurement.
    Returns ``float('inf')`` on out-of-memory; other RuntimeErrors propagate.
    """
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            import torch.mps
            torch.mps.synchronize()
        t = time.perf_counter()
        if not args.with_backward:
            with torch.no_grad():
                for _ in range(iters):
                    func(x)
        else:
            x = x.requires_grad_()
            for _ in range(iters):
                out = func(x)
                out = out[0] if isinstance(out, tuple) else out
                torch.autograd.grad(out, x, out, only_inputs=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            import torch.mps
            torch.mps.synchronize()
        return time.perf_counter() - t
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            # Fix: re-raise instead of `raise RuntimeError(e)`, which wrapped
            # the exception and discarded the original traceback.
            raise
        torch.cuda.empty_cache()
        return float('inf')
def timing(dataset):
    """Benchmark four SpMM implementations (scatter_add, SciPy CSR,
    torch.sparse COO, torch_sparse SparseTensor) on one SuiteSparse matrix and
    print a timing table, highlighting the per-column winner in bold."""
    group, name = dataset
    mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
    row = torch.from_numpy(mat_scipy.tocoo().row).to(args.device, torch.long)
    col = torch.from_numpy(mat_scipy.tocoo().col).to(args.device, torch.long)
    mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
    mat.fill_cache_()
    mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
    mat_scipy = mat.to_scipy(layout='csr')

    def scatter(x):
        return scatter_add(x[col], row, dim=0, dim_size=mat_scipy.shape[0])

    def spmm_scipy(x):
        # SciPy cannot consume CUDA tensors; treated like an OOM skip below.
        if x.is_cuda:
            raise RuntimeError('out of memory')
        return mat_scipy @ x

    def spmm_pytorch(x):
        return mat_pytorch @ x

    def spmm(x):
        return mat @ x

    t1, t2, t3, t4 = [], [], [], []
    for size in sizes:
        try:
            x = torch.randn((mat.size(1), size), device=args.device)
            t1 += [time_func(scatter, x)]
            t2 += [time_func(spmm_scipy, x)]
            t3 += [time_func(spmm_pytorch, x)]
            t4 += [time_func(spmm, x)]
            del x
        except RuntimeError as e:
            if 'out of memory' not in str(e):
                # Fix: re-raise instead of `raise RuntimeError(e)`, which
                # wrapped the exception and discarded the original traceback.
                raise
            torch.cuda.empty_cache()
            # Mark the failed size as infinitely slow for every contender.
            for t in (t1, t2, t3, t4):
                t.append(float('inf'))
    ts = torch.tensor([t1, t2, t3, t4])
    winner = torch.zeros_like(ts, dtype=torch.bool)
    winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
    winner = winner.tolist()
    name = f'{group}/{name}'
    print(f'{bold(name)} (avg row length: {mat.avg_row_length():.2f}):')
    print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
    print('\t'.join([bold('Scatter ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
    print('\t'.join([bold('SPMM SciPy ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
    print('\t'.join([bold('SPMM PyTorch')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
    print('\t'.join([bold('SPMM Own ')] +
                    [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
    print()
if __name__ == '__main__':
    # CLI: --device selects cpu/cuda, --with_backward also times backprop.
    parser = argparse.ArgumentParser()
    parser.add_argument('--with_backward', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
    args = parser.parse_args()
    # Fewer iterations and smaller feature sizes on CPU to keep runtime sane.
    iters = 1 if args.device == 'cpu' else 20
    sizes = [1, 16, 32, 64, 128, 256, 512]
    sizes = sizes[:4] if args.device == 'cpu' else sizes
    for _ in range(10):  # Warmup.
        torch.randn(100, 100, device=args.device).sum()
    # Download each dataset, verify correctness, then benchmark it.
    for dataset in itertools.chain(short_rows, long_rows):
        download(dataset)
        correctness(dataset)
        timing(dataset)
import os
import wget
import time
import errno
import argparse
import os.path as osp
import torch
from scipy.io import loadmat
# CLI: --root is the cache directory for downloaded .mat files.
parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='/tmp/test_ptr2ind')
args = parser.parse_args()
# (group, name) pairs identifying matrices in the SuiteSparse collection.
matrices = [
    ('DIMACS10', 'citationCiteseer'),
    ('SNAP', 'web-Stanford'),
    ('Janna', 'StocF-1465'),
    ('GHS_psdef', 'ldoor'),
]
def get_torch_sparse_coo_tensor(root, group, name):
    """Return the SuiteSparse matrix ``group/name`` as a coalesced
    ``torch.sparse_coo_tensor`` (float values, int64 indices), downloading
    the ``.mat`` file into ``root`` on first use.
    """
    path = osp.join(root, f'{name}.mat')
    if not osp.exists(path):
        # Fix: the old try/except tested `osp.isdir(path)` -- the .mat *file*
        # path, not the directory -- so the raise branch was nearly
        # unreachable and real makedirs() failures were silently swallowed.
        # exist_ok=True gives the intended "already exists is fine" semantics
        # while letting genuine errors (permissions, bad path) propagate.
        os.makedirs(root, exist_ok=True)
        url = f'https://sparse.tamu.edu/mat/{group}/{name}.mat'
        print(f'Downloading {group}/{name}:')
        wget.download(url, path)
    matrix = loadmat(path)['Problem'][0][0][2].tocoo()
    row = torch.from_numpy(matrix.row).to(torch.long)
    col = torch.from_numpy(matrix.col).to(torch.long)
    index = torch.stack([row, col], dim=0)
    value = torch.from_numpy(matrix.data).to(torch.float)
    print(f'{name}.mat: shape={matrix.shape} nnz={row.numel()}')
    return torch.sparse_coo_tensor(index, value, matrix.shape).coalesce()
def time_func(matrix, op, duration=5.0, warmup=1.0):
    """Average wall-clock seconds per call of ``op(matrix)``.

    Runs ``op`` repeatedly for ``warmup`` seconds, then measures for roughly
    ``duration`` seconds and returns elapsed-time / call-count.

    Fix: the original called ``torch.cuda.synchronize()`` unconditionally,
    which raises on CUDA-less builds even though the driver loop also times
    the 'cpu' device; the sync is now guarded by ``torch.cuda.is_available()``.
    """
    def _sync():
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    t = time.time()
    while (time.time() - t) < warmup:
        op(matrix)
    _sync()

    count = 0
    t = time.time()
    while (time.time() - t) < duration:
        op(matrix)
        count += 1
    _sync()
    return (time.time() - t) / count
def bucketize(matrix):
    """Build a CSR-style row pointer (length ``nrows + 1``) for a coalesced
    sparse COO matrix by binary-searching each row boundary into the sorted
    row-index vector via ``torch.bucketize``."""
    rows = matrix.indices()[0]
    queries = torch.arange(matrix.size(0) + 1, device=rows.device)
    return torch.bucketize(queries, rows)
def convert_coo_to_csr(matrix):
    # Same row-pointer computation as bucketize(), but delegated to the
    # private torch helper `torch._convert_coo_to_csr` (internal API --
    # may change or disappear between torch releases).
    row_indices = matrix.indices()[0]
    return torch._convert_coo_to_csr(row_indices, matrix.size(0))
# Compare the bucketize-based row-pointer construction against the private
# torch helper on CPU and CUDA: assert identical results, then time both.
for device in ['cpu', 'cuda']:
    print('DEVICE:', device)
    for group, name in matrices:
        matrix = get_torch_sparse_coo_tensor(args.root, group, name)
        matrix = matrix.to(device)
        out1 = bucketize(matrix)
        out2 = convert_coo_to_csr(matrix)
        assert out1.tolist() == out2.tolist()
        t = time_func(matrix, bucketize, duration=5.0, warmup=1.0)
        print('old impl:', t)
        t = time_func(matrix, convert_coo_to_csr, duration=5.0, warmup=1.0)
        print('new impl:', t)
        print()
###
#
# @copyright (c) 2009-2014 The University of Tennessee and The University
# of Tennessee Research Foundation.
# All rights reserved.
# @copyright (c) 2012-2014 Inria. All rights reserved.
# @copyright (c) 2012-2014 Bordeaux INP, CNRS (LaBRI UMR 5800), Inria, Univ. Bordeaux. All rights reserved.
#
###
#
# - Find METIS include dirs and libraries
# Use this module by invoking find_package with the form:
# find_package(METIS
# [REQUIRED] # Fail with error if metis is not found
# )
#
# This module finds headers and metis library.
# Results are reported in variables:
# METIS_FOUND - True if headers and requested libraries were found
# METIS_INCLUDE_DIRS - metis include directories
# METIS_LIBRARY_DIRS - Link directories for metis libraries
# METIS_LIBRARIES - metis component libraries to be linked
#
# The user can give specific paths where to find the libraries adding cmake
# options at configure (ex: cmake path/to/project -DMETIS_DIR=path/to/metis):
# METIS_DIR - Where to find the base directory of metis
# METIS_INCDIR - Where to find the header files
# METIS_LIBDIR - Where to find the library files
# The module can also look for the following environment variables if paths
# are not given as cmake variable: METIS_DIR, METIS_INCDIR, METIS_LIBDIR
#=============================================================================
# Copyright 2012-2013 Inria
# Copyright 2012-2013 Emmanuel Agullo
# Copyright 2012-2013 Mathieu Faverge
# Copyright 2012 Cedric Castagnede
# Copyright 2013 Florent Pruvost
#
# Distributed under the OSI-approved BSD License (the "License");
# see accompanying file MORSE-Copyright.txt for details.
#
# This software is distributed WITHOUT ANY WARRANTY; without even the
# implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the License for more information.
#=============================================================================
# (To distribute this file outside of Morse, substitute the full
# License text for the above reference.)
if (NOT METIS_FOUND)
set(METIS_DIR "" CACHE PATH "Installation directory of METIS library")
if (NOT METIS_FIND_QUIETLY)
message(STATUS "A cache variable, namely METIS_DIR, has been set to specify the install directory of METIS")
endif()
endif()
# Looking for include
# -------------------
# Add system include paths to search include
# ------------------------------------------
# Build _inc_env, the candidate include search path, from (in priority order)
# the METIS_INCDIR / METIS_DIR environment variables, then the generic
# include-path environment variables, then the compiler's implicit
# include directories.
unset(_inc_env)
set(ENV_METIS_DIR "$ENV{METIS_DIR}")
set(ENV_METIS_INCDIR "$ENV{METIS_INCDIR}")
if(ENV_METIS_INCDIR)
list(APPEND _inc_env "${ENV_METIS_INCDIR}")
elseif(ENV_METIS_DIR)
list(APPEND _inc_env "${ENV_METIS_DIR}")
list(APPEND _inc_env "${ENV_METIS_DIR}/include")
list(APPEND _inc_env "${ENV_METIS_DIR}/include/metis")
else()
if(WIN32)
string(REPLACE ":" ";" _inc_env "$ENV{INCLUDE}")
else()
string(REPLACE ":" ";" _path_env "$ENV{INCLUDE}")
list(APPEND _inc_env "${_path_env}")
string(REPLACE ":" ";" _path_env "$ENV{C_INCLUDE_PATH}")
list(APPEND _inc_env "${_path_env}")
string(REPLACE ":" ";" _path_env "$ENV{CPATH}")
list(APPEND _inc_env "${_path_env}")
string(REPLACE ":" ";" _path_env "$ENV{INCLUDE_PATH}")
list(APPEND _inc_env "${_path_env}")
endif()
endif()
list(APPEND _inc_env "${CMAKE_PLATFORM_IMPLICIT_INCLUDE_DIRECTORIES}")
list(APPEND _inc_env "${CMAKE_C_IMPLICIT_INCLUDE_DIRECTORIES}")
list(REMOVE_DUPLICATES _inc_env)
# Try to find the metis header in the given paths
# -------------------------------------------------
# call cmake macro to find the header path
# The cache entry is reset to NOTFOUND first so a stale result from a
# previous configure run does not short-circuit find_path().
if(METIS_INCDIR)
set(METIS_metis.h_DIRS "METIS_metis.h_DIRS-NOTFOUND")
find_path(METIS_metis.h_DIRS
NAMES metis.h
HINTS ${METIS_INCDIR})
else()
if(METIS_DIR)
set(METIS_metis.h_DIRS "METIS_metis.h_DIRS-NOTFOUND")
find_path(METIS_metis.h_DIRS
NAMES metis.h
HINTS ${METIS_DIR}
PATH_SUFFIXES "include" "include/metis")
else()
set(METIS_metis.h_DIRS "METIS_metis.h_DIRS-NOTFOUND")
find_path(METIS_metis.h_DIRS
NAMES metis.h
HINTS ${_inc_env})
endif()
endif()
mark_as_advanced(METIS_metis.h_DIRS)
# If found, add path to cmake variable
# ------------------------------------
if (METIS_metis.h_DIRS)
set(METIS_INCLUDE_DIRS "${METIS_metis.h_DIRS}")
else ()
set(METIS_INCLUDE_DIRS "METIS_INCLUDE_DIRS-NOTFOUND")
if(NOT METIS_FIND_QUIETLY)
message(STATUS "Looking for metis -- metis.h not found")
endif()
endif()
# Looking for lib
# ---------------
# Add system library paths to search lib
# --------------------------------------
# Build _lib_env, the candidate library search path, mirroring the include
# search above: METIS_LIBDIR / METIS_DIR environment variables first, then
# the platform's dynamic-loader path and implicit link directories.
unset(_lib_env)
set(ENV_METIS_LIBDIR "$ENV{METIS_LIBDIR}")
if(ENV_METIS_LIBDIR)
list(APPEND _lib_env "${ENV_METIS_LIBDIR}")
elseif(ENV_METIS_DIR)
list(APPEND _lib_env "${ENV_METIS_DIR}")
list(APPEND _lib_env "${ENV_METIS_DIR}/lib")
else()
if(WIN32)
string(REPLACE ":" ";" _lib_env "$ENV{LIB}")
else()
if(APPLE)
string(REPLACE ":" ";" _lib_env "$ENV{DYLD_LIBRARY_PATH}")
else()
string(REPLACE ":" ";" _lib_env "$ENV{LD_LIBRARY_PATH}")
endif()
list(APPEND _lib_env "${CMAKE_PLATFORM_IMPLICIT_LINK_DIRECTORIES}")
list(APPEND _lib_env "${CMAKE_C_IMPLICIT_LINK_DIRECTORIES}")
endif()
endif()
list(REMOVE_DUPLICATES _lib_env)
# Try to find the metis lib in the given paths
# ----------------------------------------------
# call cmake macro to find the lib path
# As with the header search, the cache entry is reset to NOTFOUND first.
if(METIS_LIBDIR)
set(METIS_metis_LIBRARY "METIS_metis_LIBRARY-NOTFOUND")
find_library(METIS_metis_LIBRARY
NAMES metis
HINTS ${METIS_LIBDIR})
else()
if(METIS_DIR)
set(METIS_metis_LIBRARY "METIS_metis_LIBRARY-NOTFOUND")
find_library(METIS_metis_LIBRARY
NAMES metis
HINTS ${METIS_DIR}
PATH_SUFFIXES lib lib32 lib64)
else()
set(METIS_metis_LIBRARY "METIS_metis_LIBRARY-NOTFOUND")
find_library(METIS_metis_LIBRARY
NAMES metis
HINTS ${_lib_env})
endif()
endif()
mark_as_advanced(METIS_metis_LIBRARY)
# If found, add path to cmake variable
# ------------------------------------
if (METIS_metis_LIBRARY)
get_filename_component(metis_lib_path "${METIS_metis_LIBRARY}" PATH)
# set cmake variables
set(METIS_LIBRARIES "${METIS_metis_LIBRARY}")
set(METIS_LIBRARY_DIRS "${metis_lib_path}")
else ()
set(METIS_LIBRARIES "METIS_LIBRARIES-NOTFOUND")
set(METIS_LIBRARY_DIRS "METIS_LIBRARY_DIRS-NOTFOUND")
if(NOT METIS_FIND_QUIETLY)
message(STATUS "Looking for metis -- lib metis not found")
endif()
endif ()
# check a function to validate the find
# Attempt an actual link of METIS_NodeND against the discovered library (plus
# libm when present) via check_function_exists; the result is METIS_WORKS,
# which find_package_handle_standard_args requires below.
if(METIS_LIBRARIES)
set(REQUIRED_INCDIRS)
set(REQUIRED_LIBDIRS)
set(REQUIRED_LIBS)
# METIS
if (METIS_INCLUDE_DIRS)
set(REQUIRED_INCDIRS "${METIS_INCLUDE_DIRS}")
endif()
if (METIS_LIBRARY_DIRS)
set(REQUIRED_LIBDIRS "${METIS_LIBRARY_DIRS}")
endif()
set(REQUIRED_LIBS "${METIS_LIBRARIES}")
# m
find_library(M_LIBRARY NAMES m)
mark_as_advanced(M_LIBRARY)
if(M_LIBRARY)
list(APPEND REQUIRED_LIBS "-lm")
endif()
# set required libraries for link
set(CMAKE_REQUIRED_INCLUDES "${REQUIRED_INCDIRS}")
set(CMAKE_REQUIRED_LIBRARIES)
foreach(lib_dir ${REQUIRED_LIBDIRS})
list(APPEND CMAKE_REQUIRED_LIBRARIES "-L${lib_dir}")
endforeach()
list(APPEND CMAKE_REQUIRED_LIBRARIES "${REQUIRED_LIBS}")
string(REGEX REPLACE "^ -" "-" CMAKE_REQUIRED_LIBRARIES "${CMAKE_REQUIRED_LIBRARIES}")
# test link
unset(METIS_WORKS CACHE)
include(CheckFunctionExists)
check_function_exists(METIS_NodeND METIS_WORKS)
mark_as_advanced(METIS_WORKS)
if(NOT METIS_WORKS)
if(NOT METIS_FIND_QUIETLY)
message(STATUS "Looking for METIS : test of METIS_NodeND with METIS library fails")
message(STATUS "CMAKE_REQUIRED_LIBRARIES: ${CMAKE_REQUIRED_LIBRARIES}")
message(STATUS "CMAKE_REQUIRED_INCLUDES: ${CMAKE_REQUIRED_INCLUDES}")
message(STATUS "Check in CMakeFiles/CMakeError.log to figure out why it fails")
endif()
endif()
# Reset the CMAKE_REQUIRED_* state so later checks are unaffected.
set(CMAKE_REQUIRED_INCLUDES)
set(CMAKE_REQUIRED_FLAGS)
set(CMAKE_REQUIRED_LIBRARIES)
endif()
# Derive METIS_DIR_FOUND (the install prefix) by stripping a trailing
# /lib, /lib32 or /lib64 component from the first library's directory.
if (METIS_LIBRARIES)
list(GET METIS_LIBRARIES 0 first_lib)
get_filename_component(first_lib_path "${first_lib}" PATH)
if (${first_lib_path} MATCHES "/lib(32|64)?$")
string(REGEX REPLACE "/lib(32|64)?$" "" not_cached_dir "${first_lib_path}")
set(METIS_DIR_FOUND "${not_cached_dir}" CACHE PATH "Installation directory of METIS library" FORCE)
else()
set(METIS_DIR_FOUND "${first_lib_path}" CACHE PATH "Installation directory of METIS library" FORCE)
endif()
endif()
mark_as_advanced(METIS_DIR)
mark_as_advanced(METIS_DIR_FOUND)
# check that METIS has been found
# ---------------------------------
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(METIS DEFAULT_MSG
METIS_LIBRARIES
METIS_WORKS
METIS_INCLUDE_DIRS)
#
# TODO: Add possibility to check for specific functions in the library
#
# TorchSparseConfig.cmake
# --------------------
#
# Exported targets:: TorchSparse
#
@PACKAGE_INIT@
set(PN TorchSparse)
# Legacy-style variables for consumers that do not use the imported target.
set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/@CMAKE_INSTALL_INCLUDEDIR@")
set(${PN}_LIBRARY "")
set(${PN}_DEFINITIONS USING_${PN})
check_required_components(${PN})
if(NOT (CMAKE_VERSION VERSION_LESS 3.0))
#-----------------------------------------------------------------------------
# Don't include targets if this file is being picked up by another
# project which has already built this as a subproject
#-----------------------------------------------------------------------------
if(NOT TARGET ${PN}::TorchSparse)
include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake")
# Re-resolve the Torch and Python dependencies in the consumer project so
# the interface link below has its targets/variables available.
if(NOT TARGET torch_library)
find_package(Torch REQUIRED)
endif()
if(NOT TARGET Python3::Python)
find_package(Python3 COMPONENTS Development)
endif()
target_link_libraries(TorchSparse::TorchSparse INTERFACE ${TORCH_LIBRARIES} Python3::Python)
if(@WITH_CUDA@)
target_compile_definitions(TorchSparse::TorchSparse INTERFACE WITH_CUDA)
endif()
endif()
endif()
```
./build_conda.sh 3.9 2.1.0 cu118 # python, pytorch and cuda version
```
copy "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC\\Tools\\MSVC\\14.29.30133\\lib\\x64\\metis.lib" %LIBRARY_LIB%
if errorlevel 1 exit 1
copy "C:\\Program Files (x86)\\Microsoft Visual Studio\\2019\\Enterprise\\VC\\Tools\\MSVC\\14.29.30133\\include\\metis.h" %LIBRARY_INC%
if errorlevel 1 exit 1
"%PYTHON%" -m pip install .
if errorlevel 1 exit 1
$PYTHON -m pip install .
#!/bin/bash
# Build a pytorch-sparse conda package for one python/pytorch/cuda combination.
# Usage: ./build_conda.sh <python-version> <torch-version> <cuXYZ|cpu>
export PYTHON_VERSION=$1
export TORCH_VERSION=$2
export CUDA_VERSION=$3
# Pin pytorch to the requested major.minor series (any patch release).
export CONDA_PYTORCH_CONSTRAINT="pytorch==${TORCH_VERSION%.*}.*"
if [ "${CUDA_VERSION}" = "cpu" ]; then
export CONDA_CUDATOOLKIT_CONSTRAINT="cpuonly # [not osx]"
else
# Map the cuXYZ tag to the matching conda constraint: pytorch-cuda for newer
# toolkits, cudatoolkit for older ones (cu116 depends on the torch version).
case $CUDA_VERSION in
cu121)
export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda==12.1.*"
;;
cu118)
export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda==11.8.*"
;;
cu117)
export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda==11.7.*"
;;
cu116)
if [ "${TORCH_VERSION}" = "1.12.0" ]; then
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.6.*"
else
export CONDA_CUDATOOLKIT_CONSTRAINT="pytorch-cuda==11.6.*"
fi
;;
cu115)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.5.*"
;;
cu113)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.3.*"
;;
cu111)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==11.1.*"
;;
cu102)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==10.2.*"
;;
cu101)
export CONDA_CUDATOOLKIT_CONSTRAINT="cudatoolkit==10.1.*"
;;
*)
echo "Unrecognized CUDA_VERSION=$CUDA_VERSION"
exit 1
;;
esac
fi
echo "PyTorch $TORCH_VERSION+$CUDA_VERSION"
echo "- $CONDA_PYTORCH_CONSTRAINT"
echo "- $CONDA_CUDATOOLKIT_CONSTRAINT"
# torch 1.12.0+cu116 additionally needs the conda-forge channel.
if [ "${TORCH_VERSION}" = "1.12.0" ] && [ "${CUDA_VERSION}" = "cu116" ]; then
conda build . -c pytorch -c pyg -c default -c nvidia -c conda-forge --output-folder "$HOME/conda-bld"
else
conda build . -c pytorch -c pyg -c default -c nvidia --output-folder "$HOME/conda-bld"
fi
# Conda recipe for the pytorch-sparse package (imports as `torch_sparse`).
# PYTHON_VERSION / TORCH_VERSION / CUDA_VERSION and the CONSTRAINT variables
# are exported by conda/pytorch-sparse/build_conda.sh before `conda build`.
package:
  name: pytorch-sparse
  version: 0.6.18
source:
  path: ../..
requirements:
  build:
    - {{ compiler('c') }} # [win]
  host:
    - pip
    - python {{ environ.get('PYTHON_VERSION') }}
    - {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
    - {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
  run:
    - scipy
    - pytorch-scatter
    - python {{ environ.get('PYTHON_VERSION') }}
    - {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
    - {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
build:
  # Build string example: py39_torch_2.1.0_cu118
  string: py{{ environ.get('PYTHON_VERSION').replace('.', '') }}_torch_{{ environ['TORCH_VERSION'] }}_{{ environ['CUDA_VERSION'] }}
  script_env:
    - FORCE_CUDA
    - TORCH_CUDA_ARCH_LIST
    - WITH_METIS=1
  preserve_egg_dir: True
test:
  imports:
    - torch_sparse
about:
  home: https://github.com/rusty1s/pytorch_sparse
  license: MIT
  summary: PyTorch Extension Library of Optimized Autograd Sparse Matrix Operations
......@@ -5,13 +5,13 @@
#include "cpu/convert_cpu.h"
#ifdef WITH_HIP
#include "hip/convert_hip.h"
#ifdef WITH_CUDA
#include "cuda/convert_cuda.h"
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_HIP
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__convert_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__convert_cpu(void) { return NULL; }
......@@ -21,7 +21,7 @@ PyMODINIT_FUNC PyInit__convert_cpu(void) { return NULL; }
SPARSE_API torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
if (ind.device().is_cuda()) {
#ifdef WITH_HIP
#ifdef WITH_CUDA
return ind2ptr_cuda(ind, M);
#else
AT_ERROR("Not compiled with CUDA support");
......@@ -33,7 +33,7 @@ SPARSE_API torch::Tensor ind2ptr(torch::Tensor ind, int64_t M) {
SPARSE_API torch::Tensor ptr2ind(torch::Tensor ptr, int64_t E) {
if (ptr.device().is_cuda()) {
#ifdef WITH_HIP
#ifdef WITH_CUDA
return ptr2ind_cuda(ptr, E);
#else
AT_ERROR("Not compiled with CUDA support");
......
......@@ -13,7 +13,7 @@ torch::Tensor non_diag_mask_cpu(torch::Tensor row, torch::Tensor col, int64_t M,
auto row_data = row.data_ptr<int64_t>();
auto col_data = col.data_ptr<int64_t>();
auto mask = torch::zeros(E + num_diag, row.options().dtype(torch::kBool));
auto mask = torch::zeros({E + num_diag}, row.options().dtype(torch::kBool));
auto mask_data = mask.data_ptr<bool>();
int64_t r, c;
......
......@@ -19,8 +19,6 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
torch::Tensor idx, int64_t depth,
int64_t num_neighbors, bool replace) {
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
std::vector<torch::Tensor> out_rowptrs(idx.numel() + 1);
std::vector<torch::Tensor> out_cols(idx.numel());
std::vector<torch::Tensor> out_n_ids(idx.numel());
......@@ -56,14 +54,14 @@ ego_k_hop_sample_adj_cpu(torch::Tensor rowptr, torch::Tensor col,
}
} else if (replace) {
for (int64_t j = 0; j < num_neighbors; j++) {
w = col_data[row_start + (rand() % row_count)];
w = col_data[row_start + uniform_randint(row_count)];
n_id_set.insert(w);
n_ids.push_back(w);
}
} else {
std::unordered_set<int64_t> perm;
for (int64_t j = row_count - num_neighbors; j < row_count; j++) {
if (!perm.insert(rand() % j).second) {
if (!perm.insert(uniform_randint(j)).second) {
perm.insert(j);
}
}
......
......@@ -105,8 +105,6 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<node_t, vector<int64_t>> &num_samples_dict,
const int64_t num_hops) {
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
// Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type;
for (const auto &kv : colptr_dict) {
......@@ -226,12 +224,10 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
}
}
}
if (rows.size() > 0) {
out_row_dict.insert(rel_type, from_vector<int64_t>(rows));
out_col_dict.insert(rel_type, from_vector<int64_t>(cols));
out_edge_dict.insert(rel_type, from_vector<int64_t>(edges));
}
}
}
// Generate tensor-valued output node dictionary (line 20):
for (const auto &kv : nodes_dict) {
......
......@@ -44,7 +44,7 @@ torch::Tensor partition_cpu(torch::Tensor rowptr, torch::Tensor col,
vwgt = optional_node_weight.value().data_ptr<int64_t>();
int64_t objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part = torch::empty({nvtxs}, rowptr.options());
auto part_data = part.data_ptr<int64_t>();
if (recursive) {
......@@ -99,7 +99,7 @@ mt_partition_cpu(torch::Tensor rowptr, torch::Tensor col,
mtmetis_pid_type nparts = num_parts;
mtmetis_wgt_type objval = -1;
auto part = torch::empty(nvtxs, rowptr.options());
auto part = torch::empty({nvtxs}, rowptr.options());
mtmetis_pid_type *part_data = (mtmetis_pid_type *)part.data_ptr<int64_t>();
double *opts = mtmetis_init_options();
......
......@@ -10,16 +10,16 @@ using namespace std;
namespace {
typedef phmap::flat_hash_map<pair<int64_t, int64_t>, int64_t> temporarl_edge_dict;
template <bool replace, bool directed>
tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
sample(const torch::Tensor &colptr, const torch::Tensor &row,
const torch::Tensor &input_node, const vector<int64_t> num_neighbors) {
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
// Initialize some data structures for the sampling process:
vector<int64_t> samples;
unordered_map<int64_t, int64_t> to_local_node;
phmap::flat_hash_map<int64_t, int64_t> to_local_node;
auto *colptr_data = colptr.data_ptr<int64_t>();
auto *row_data = row.data_ptr<int64_t>();
......@@ -59,7 +59,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
}
} else if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
const int64_t offset = col_start + uniform_randint(col_count);
const int64_t &v = row_data[offset];
const auto res = to_local_node.insert({v, samples.size()});
if (res.second)
......@@ -73,7 +73,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
} else {
unordered_set<int64_t> rnd_indices;
for (int64_t j = col_count - num_samples; j < col_count; j++) {
int64_t rnd = rand() % j;
int64_t rnd = uniform_randint(j);
if (!rnd_indices.insert(rnd).second) {
rnd = j;
rnd_indices.insert(j);
......@@ -95,7 +95,7 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
}
if (!directed) {
unordered_map<int64_t, int64_t>::iterator iter;
phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
for (int64_t i = 0; i < (int64_t)samples.size(); i++) {
const auto &w = samples[i];
const auto &col_start = colptr_data[w];
......@@ -116,7 +116,20 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
from_vector<int64_t>(cols), from_vector<int64_t>(edges));
}
template <bool replace, bool directed>
inline bool satisfy_time(const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const node_t &src_node_type, int64_t dst_time,
int64_t src_node) {
try {
// Check whether src -> dst obeys the time constraint
const torch::Tensor &src_node_time = node_time_dict.at(src_node_type);
return src_node_time.data_ptr<int64_t>()[src_node] <= dst_time;
} catch (const std::out_of_range& e) {
// If no time is given, fall back to normal sampling
return true;
}
}
template <bool replace, bool directed, bool temporal>
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_sample(const vector<node_t> &node_types,
......@@ -125,24 +138,29 @@ hetero_sample(const vector<node_t> &node_types,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops) {
srand(time(NULL) + 1000 * getpid()); // Initialize random seed.
// Create a mapping to convert single string relations to edge type triplets:
unordered_map<rel_t, edge_t> to_edge_type;
phmap::flat_hash_map<rel_t, edge_t> to_edge_type;
for (const auto &k : edge_types)
to_edge_type[get<0>(k) + "__" + get<1>(k) + "__" + get<2>(k)] = k;
// Initialize some data structures for the sampling process:
unordered_map<node_t, vector<int64_t>> samples_dict;
unordered_map<node_t, unordered_map<int64_t, int64_t>> to_local_node_dict;
phmap::flat_hash_map<node_t, vector<int64_t>> samples_dict;
phmap::flat_hash_map<node_t, vector<pair<int64_t, int64_t>>> temp_samples_dict;
phmap::flat_hash_map<node_t, phmap::flat_hash_map<int64_t, int64_t>> to_local_node_dict;
phmap::flat_hash_map<node_t, temporarl_edge_dict> temp_to_local_node_dict;
phmap::flat_hash_map<node_t, vector<int64_t>> root_time_dict;
for (const auto &node_type : node_types) {
samples_dict[node_type];
temp_samples_dict[node_type];
to_local_node_dict[node_type];
temp_to_local_node_dict[node_type];
root_time_dict[node_type];
}
unordered_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
phmap::flat_hash_map<rel_t, vector<int64_t>> rows_dict, cols_dict, edges_dict;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
rows_dict[rel_type];
......@@ -156,41 +174,82 @@ hetero_sample(const vector<node_t> &node_types,
const torch::Tensor &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>();
int64_t *node_time_data;
if (temporal) {
const torch::Tensor &node_time = node_time_dict.at(node_type);
node_time_data = node_time.data_ptr<int64_t>();
}
auto &samples = samples_dict.at(node_type);
auto &temp_samples = temp_samples_dict.at(node_type);
auto &to_local_node = to_local_node_dict.at(node_type);
auto &temp_to_local_node = temp_to_local_node_dict.at(node_type);
auto &root_time = root_time_dict.at(node_type);
for (int64_t i = 0; i < input_node.numel(); i++) {
const auto &v = input_node_data[i];
if (temporal) {
temp_samples.push_back({v, i});
temp_to_local_node.insert({{v, i}, i});
} else {
samples.push_back(v);
to_local_node.insert({v, i});
}
if (temporal)
root_time.push_back(node_time_data[v]);
}
}
unordered_map<node_t, pair<int64_t, int64_t>> slice_dict;
phmap::flat_hash_map<node_t, pair<int64_t, int64_t>> slice_dict;
if (temporal) {
for (const auto &kv : temp_samples_dict) {
slice_dict[kv.first] = {0, kv.second.size()};
}
} else {
for (const auto &kv : samples_dict)
slice_dict[kv.first] = {0, kv.second.size()};
}
for (int64_t ell = 0; ell < num_hops; ell++) {
vector<rel_t> all_rel_types;
for (const auto &kv : num_neighbors_dict) {
const auto &rel_type = kv.key();
all_rel_types.push_back(kv.key());
}
std::sort(all_rel_types.begin(), all_rel_types.end());
for (int64_t ell = 0; ell < num_hops; ell++) {
for (const auto &rel_type : all_rel_types) {
const auto &edge_type = to_edge_type[rel_type];
const auto &src_node_type = get<0>(edge_type);
const auto &dst_node_type = get<2>(edge_type);
const auto num_samples = kv.value()[ell];
const auto num_samples = num_neighbors_dict.at(rel_type)[ell];
const auto &dst_samples = samples_dict.at(dst_node_type);
const auto &temp_dst_samples = temp_samples_dict.at(dst_node_type);
auto &src_samples = samples_dict.at(src_node_type);
auto &temp_src_samples = temp_samples_dict.at(src_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
auto &temp_to_local_src_node = temp_to_local_node_dict.at(src_node_type);
const auto *colptr_data = ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
const torch::Tensor &colptr = colptr_dict.at(rel_type);
const auto *colptr_data = colptr.data_ptr<int64_t>();
const torch::Tensor &row = row_dict.at(rel_type);
const auto *row_data = row.data_ptr<int64_t>();
auto &rows = rows_dict.at(rel_type);
auto &cols = cols_dict.at(rel_type);
auto &edges = edges_dict.at(rel_type);
// For temporal sampling, sampled nodes cannot have a timestamp greater
// than the timestamp of the root nodes:
const auto &dst_root_time = root_time_dict.at(dst_node_type);
auto &src_root_time = root_time_dict.at(src_node_type);
const auto &begin = slice_dict.at(dst_node_type).first;
const auto &end = slice_dict.at(dst_node_type).second;
for (int64_t i = begin; i < end; i++) {
const auto &w = dst_samples[i];
const auto &w = temporal ? temp_dst_samples[i].first : dst_samples[i];
const int64_t root_w = temporal ? temp_dst_samples[i].second : -1;
int64_t dst_time = 0;
if (temporal)
dst_time = dst_root_time[i];
const auto &col_start = colptr_data[w];
const auto &col_end = colptr_data[w + 1];
const auto col_count = col_end - col_start;
......@@ -199,8 +258,26 @@ hetero_sample(const vector<node_t> &node_types,
continue;
if ((num_samples < 0) || (!replace && (num_samples >= col_count))) {
// Select all neighbors:
for (int64_t offset = col_start; offset < col_end; offset++) {
const int64_t &v = row_data[offset];
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
// force disjoint of computation tree based on source batch idx.
// note that the sampling always needs to have directed=True
// for temporal case
// to_local_src_node is not used for temporal / directed case
const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
if (res.second) {
temp_src_samples.push_back({v, root_w});
src_root_time.push_back(dst_time);
}
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
......@@ -210,10 +287,30 @@ hetero_sample(const vector<node_t> &node_types,
edges.push_back(offset);
}
}
}
} else if (replace) {
for (int64_t j = 0; j < num_samples; j++) {
const int64_t offset = col_start + rand() % col_count;
// Sample with replacement:
int64_t num_neighbors = 0;
while (num_neighbors < num_samples) {
const int64_t offset = col_start + uniform_randint(col_count);
const int64_t &v = row_data[offset];
if (temporal) {
// TODO Infinity loop if no neighbor satisfies time constraint:
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
// force disjoint of computation tree based on source batch idx.
// note that the sampling always needs to have directed=True
// for temporal case
const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
if (res.second) {
temp_src_samples.push_back({v, root_w});
src_root_time.push_back(dst_time);
}
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
......@@ -223,16 +320,35 @@ hetero_sample(const vector<node_t> &node_types,
edges.push_back(offset);
}
}
num_neighbors += 1;
}
} else {
// Sample without replacement:
unordered_set<int64_t> rnd_indices;
for (int64_t j = col_count - num_samples; j < col_count; j++) {
int64_t rnd = rand() % j;
int64_t rnd = uniform_randint(j);
if (!rnd_indices.insert(rnd).second) {
rnd = j;
rnd_indices.insert(j);
}
const int64_t offset = col_start + rnd;
const int64_t &v = row_data[offset];
if (temporal) {
if (!satisfy_time(node_time_dict, src_node_type, dst_time, v))
continue;
// force disjoint of computation tree based on source batch idx.
// note that the sampling always needs to have directed=True
// for temporal case
const auto res = temp_to_local_src_node.insert({{v, root_w}, (int64_t)temp_src_samples.size()});
if (res.second) {
temp_src_samples.push_back({v, root_w});
src_root_time.push_back(dst_time);
}
cols.push_back(i);
rows.push_back(res.first->second);
edges.push_back(offset);
} else {
const auto res = to_local_src_node.insert({v, src_samples.size()});
if (res.second)
src_samples.push_back(v);
......@@ -245,14 +361,22 @@ hetero_sample(const vector<node_t> &node_types,
}
}
}
}
for (const auto &kv : samples_dict) {
if (temporal) {
for (const auto &kv : temp_samples_dict) {
slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
}
} else {
for (const auto &kv : samples_dict)
slice_dict[kv.first] = {slice_dict.at(kv.first).second, kv.second.size()};
}
}
// Temporal sample disable undirected
assert(!(temporal && !directed));
if (!directed) { // Construct the subgraph among the sampled nodes:
unordered_map<int64_t, int64_t>::iterator iter;
phmap::flat_hash_map<int64_t, int64_t>::iterator iter;
for (const auto &kv : colptr_dict) {
const auto &rel_type = kv.key();
const auto &edge_type = to_edge_type[rel_type];
......@@ -262,7 +386,8 @@ hetero_sample(const vector<node_t> &node_types,
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr<int64_t>();
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
const auto *row_data =
((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
auto &rows = rows_dict.at(rel_type);
auto &cols = cols_dict.at(rel_type);
......@@ -285,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
}
}
// Construct samples dictionary from temporal sample dictionary.
if (temporal) {
for (const auto &kv : temp_samples_dict) {
const auto &node_type = kv.first;
const auto &samples = kv.second;
samples_dict[node_type].reserve(samples.size());
for (const auto &v : samples) {
samples_dict[node_type].push_back(v.first);
}
}
}
return make_tuple(from_vector<node_t, int64_t>(samples_dict),
from_vector<rel_t, int64_t>(rows_dict),
from_vector<rel_t, int64_t>(cols_dict),
......@@ -320,21 +457,51 @@ hetero_neighbor_sample_cpu(
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const int64_t num_hops, const bool replace, const bool directed) {
c10::Dict<node_t, torch::Tensor> node_time_dict; // Empty dictionary.
if (replace && directed) {
return hetero_sample<true, true>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
return hetero_sample<true, true, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else if (replace && !directed) {
return hetero_sample<true, false>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
return hetero_sample<true, false, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else if (!replace && directed) {
return hetero_sample<false, true>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
return hetero_sample<false, true, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else {
return hetero_sample<false, false, false>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
}
}
tuple<c10::Dict<node_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>,
c10::Dict<rel_t, torch::Tensor>, c10::Dict<rel_t, torch::Tensor>>
hetero_temporal_neighbor_sample_cpu(
const vector<node_t> &node_types, const vector<edge_t> &edge_types,
const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const c10::Dict<rel_t, torch::Tensor> &row_dict,
const c10::Dict<node_t, torch::Tensor> &input_node_dict,
const c10::Dict<rel_t, vector<int64_t>> &num_neighbors_dict,
const c10::Dict<node_t, torch::Tensor> &node_time_dict,
const int64_t num_hops, const bool replace, const bool directed) {
AT_ASSERTM(directed, "Temporal sampling requires 'directed' sampling");
if (replace) {
// We assume that directed = True for temporal sampling
// The current implementation uses disjoint computation trees
// to tackle the case of the same node sampled having different
// root time constraint.
// In future, we could extend to directed = False case,
// allowing additional edges within each computation tree.
return hetero_sample<true, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
} else {
return hetero_sample<false, false>(node_types, edge_types, colptr_dict,
row_dict, input_node_dict,
num_neighbors_dict, num_hops);
return hetero_sample<false, true, true>(
node_types, edge_types, colptr_dict, row_dict, input_node_dict,
num_neighbors_dict, node_time_dict, num_hops);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment