Commit fe79c969 authored by rusty1s's avatar rusty1s
Browse files

typo

parent b4a9e5d5
......@@ -35,9 +35,15 @@ def get_extensions():
for main in main_files:
name = main.split(os.sep)[-1][:-4]
sources = [main, osp.join(extensions_dir, 'cpu', name + '_cpu.cpp')]
if WITH_CUDA:
sources += [osp.join(extensions_dir, 'cuda', name + '_cuda.cu')]
sources = [main]
path = osp.join(extensions_dir, 'cpu', name + '_cpu.cpp')
if osp.exists(path):
sources += [path]
path = osp.join(extensions_dir, 'cuda', name + '_cuda.cpp')
if WITH_CUDA and osp.exists(path):
sources += [path]
extension = Extension(
'torch_scatter._' + name,
......
import os.path as osp
import torch
from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
scatter_max, scatter)
from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
......@@ -9,6 +13,10 @@ from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
scatter_log_softmax)
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_version.so'))
_version = torch.ops.torch_scatter.cuda_version()
__version__ = '2.0.3'
__all__ = [
......
import warnings
import os.path as osp
from typing import Optional, Tuple
......@@ -6,21 +5,8 @@ import torch
from .utils import broadcast
try:
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))
except OSError:
warnings.warn('Failed to load `scatter` binaries.')
def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
dim: int, out: Optional[torch.Tensor],
dim_size: Optional[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError
return src, index
torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))
@torch.jit.script
......
import warnings
import os.path as osp
from typing import Optional, Tuple
import torch
try:
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_coo.so'))
except OSError:
warnings.warn('Failed to load `segment_coo` binaries.')
def segment_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor],
dim_size: Optional[int]) -> torch.Tensor:
raise ImportError
return src
def segment_coo_with_arg_placeholder(
src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor],
dim_size: Optional[int]) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError
return src, index
def gather_coo_placeholder(src: torch.Tensor, index: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError
return src
torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder
torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder
torch.ops.torch_scatter.segment_min_coo = segment_coo_with_arg_placeholder
torch.ops.torch_scatter.segment_max_coo = segment_coo_with_arg_placeholder
torch.ops.torch_scatter.gather_coo = gather_coo_placeholder
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_coo.so'))
@torch.jit.script
......
import warnings
import os.path as osp
from typing import Optional, Tuple
import torch
try:
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_csr.so'))
except OSError:
warnings.warn('Failed to load `segment_csr` binaries.')
def segment_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError
return src
def segment_csr_with_arg_placeholder(
src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
raise ImportError
return src, indptr
def gather_csr_placeholder(src: torch.Tensor, indptr: torch.Tensor,
out: Optional[torch.Tensor]) -> torch.Tensor:
raise ImportError
return src
torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder
torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder
torch.ops.torch_scatter.segment_min_csr = segment_csr_with_arg_placeholder
torch.ops.torch_scatter.segment_max_csr = segment_csr_with_arg_placeholder
torch.ops.torch_scatter.gather_csr = gather_csr_placeholder
torch.ops.load_library(
osp.join(osp.dirname(osp.abspath(__file__)), '_segment_csr.so'))
@torch.jit.script
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment