Unverified Commit 003abd58 authored by Matthias Fey's avatar Matthias Fey Committed by GitHub
Browse files

Improve `scatter` performance + upgrade to `torch-scatter==2.1.0` (#338)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* typo

* typo
parent c1285089
......@@ -13,7 +13,7 @@ jobs:
# We have trouble building for Windows - drop for now.
os: [ubuntu-18.04, macos-10.15] # windows-2019
python-version: ['3.7', '3.8', '3.9', '3.10']
torch-version: [1.13.0] # [1.12.0, 1.13.0]
torch-version: [1.12.0, 1.13.0]
cuda-version: ['cpu', 'cu102', 'cu113', 'cu116', 'cu117']
exclude:
- torch-version: 1.12.0
......@@ -32,8 +32,6 @@ jobs:
cuda-version: 'cu117'
- os: windows-2019
cuda-version: 'cu102'
- os: windows-2019 # Complains about CUDA mismatch.
python-version: '3.7'
steps:
- uses: actions/checkout@v2
......
......@@ -63,7 +63,11 @@ jobs:
if: ${{ runner.os != 'macOS' }}
run: |
VERSION=`sed -n "s/^__version__ = '\(.*\)'/\1/p" torch_scatter/__init__.py`
sed -i "s/$VERSION/$VERSION+${{ matrix.cuda-version }}/" torch_scatter/__init__.py
TORCH_VERSION=`echo "pt${{ matrix.torch-version }}" | sed "s/..$//" | sed "s/\.//g"`
CUDA_VERSION=`echo ${{ matrix.cuda-version }}`
echo "New version name: $VERSION+$TORCH_VERSION$CUDA_VERSION"
sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" setup.py
sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" torch_scatter/__init__.py
shell:
bash
......
cmake_minimum_required(VERSION 3.0)
project(torchscatter)
set(CMAKE_CXX_STANDARD 14)
set(TORCHSCATTER_VERSION 2.0.9)
set(TORCHSCATTER_VERSION 2.1.0)
option(WITH_CUDA "Enable CUDA support" OFF)
option(WITH_PYTHON "Link to Python when building" ON)
......
package:
name: pytorch-scatter
version: 2.0.9
version: 2.1.0
source:
path: ../..
......
......@@ -7,7 +7,7 @@
#include "reducer.cuh"
#include "utils.cuh"
#define THREADS 1024
#define THREADS 256
#define BLOCKS(N) (N + THREADS - 1) / THREADS
template <typename scalar_t, ReductionType REDUCE>
......
......@@ -11,7 +11,7 @@ from torch.__config__ import parallel_info
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension,
CUDAExtension)
__version__ = '2.0.9'
__version__ = '2.1.0'
URL = 'https://github.com/rusty1s/pytorch_scatter'
WITH_CUDA = False
......
......@@ -4,7 +4,7 @@ import os.path as osp
import torch
__version__ = '2.0.9'
__version__ = '2.1.0'
for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']:
cuda_spec = importlib.machinery.PathFinder().find_spec(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment