Commit ea94e546 authored by rusty1s

cpu boilerplate

parent d824c8be
@@ -30,13 +30,16 @@ def correctness(dataset):
         assert torch.allclose(out1, out2, atol=1e-4)
         assert torch.allclose(out1, out3, atol=1e-4)
-    except RuntimeError:
+    except RuntimeError as e:
+        if 'out of memory' not in str(e):
+            raise RuntimeError(e)
         torch.cuda.empty_cache()


 def time_func(func, x):
     try:
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         t = time.perf_counter()
         if not args.with_backward:
@@ -49,9 +52,12 @@ def time_func(func, x):
             out = func(x)
             torch.autograd.grad(out, x, out, only_inputs=True)
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         return time.perf_counter() - t
-    except RuntimeError:
+    except RuntimeError as e:
+        if 'out of memory' not in str(e):
+            raise RuntimeError(e)
         torch.cuda.empty_cache()
         return float('inf')
@@ -88,7 +94,9 @@ def timing(dataset):
             del x
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
             for t in (t1, t2, t3, t4):
                 t.append(float('inf'))
...
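Pieced together, the reworked time_func guards every torch.cuda.synchronize() behind torch.cuda.is_available() and swallows only CUDA out-of-memory errors, so the same benchmark runs on CPU-only builds. A reconstruction from the hunks above (the function's inner iteration loop is elided from this diff, so a single call stands in for it):

import time

import torch

def time_func(func, x):
    try:
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # no-op guard on CPU-only builds
        t = time.perf_counter()
        func(x)  # placeholder for the script's elided forward/backward loop
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
        return time.perf_counter() - t
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise RuntimeError(e)  # propagate everything except CUDA OOM
        torch.cuda.empty_cache()
        return float('inf')  # report an OOM as infinite runtime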
@@ -82,13 +82,16 @@ def correctness(dataset):
         assert torch.allclose(out1, out2, atol=1e-4)
         assert torch.allclose(out1, out3, atol=1e-4)
-    except RuntimeError:
+    except RuntimeError as e:
+        if 'out of memory' not in str(e):
+            raise RuntimeError(e)
         torch.cuda.empty_cache()


 def time_func(func, x):
     try:
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         t = time.perf_counter()
         if not args.with_backward:
@@ -102,9 +105,12 @@ def time_func(func, x):
             out = out[0] if isinstance(out, tuple) else out
             torch.autograd.grad(out, x, out, only_inputs=True)
-        torch.cuda.synchronize()
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
         return time.perf_counter() - t
-    except RuntimeError:
+    except RuntimeError as e:
+        if 'out of memory' not in str(e):
+            raise RuntimeError(e)
         torch.cuda.empty_cache()
         return float('inf')
@@ -152,7 +158,9 @@ def timing(dataset):
             del x
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
             for t in (t1, t2, t3, t4):
                 t.append(float('inf'))
@@ -167,7 +175,9 @@ def timing(dataset):
             del x
-        except RuntimeError:
+        except RuntimeError as e:
+            if 'out of memory' not in str(e):
+                raise RuntimeError(e)
             torch.cuda.empty_cache()
             for t in (t5, t6):
                 t.append(float('inf'))
@@ -197,8 +207,11 @@ def timing(dataset):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--reduce', type=str, required=True,
-                        choices=['add', 'mean', 'min', 'max'])
+    parser.add_argument(
+        '--reduce',
+        type=str,
+        required=True,
+        choices=['add', 'mean', 'min', 'max'])
     parser.add_argument('--with_backward', action='store_true')
    parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
...
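The same three-line OOM guard now appears in every try/except across both benchmark scripts. One way it could be factored out is a small context manager; this is only an illustration of the repeated pattern, not part of the commit, and the name oom_tolerant is hypothetical:

from contextlib import contextmanager

import torch

@contextmanager
def oom_tolerant():
    # Hypothetical helper: swallow CUDA out-of-memory errors, re-raise the rest.
    try:
        yield
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
        torch.cuda.empty_cache()

# Usage:
# with oom_tolerant():
#     out = func(x)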
#include <torch/extension.h>

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
                      at::optional<at::Tensor> out_opt) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());

  AT_ASSERTM(false, "Not yet implemented");
  return src;
}

at::Tensor gather_coo(at::Tensor src, at::Tensor index,
                      at::optional<at::Tensor> out_opt) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());

  AT_ASSERTM(false, "Not yet implemented");
  return src;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("gather_csr", &gather_csr, "Gather CSR (CPU)");
  m.def("gather_coo", &gather_coo, "Gather COO (CPU)");
}
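Once compiled (see the setup.py hunk later in this commit), this stub is importable even though it only validates device placement; every call trips the AT_ASSERTM. A quick smoke test from Python, assuming the extension built under the name torch_scatter.gather_cpu as the setup.py naming scheme implies:

import torch
from torch_scatter import gather_cpu

src = torch.randn(3, 4)                 # one row per CSR segment
indptr = torch.tensor([0, 2, 5, 6])     # three segments

try:
    gather_cpu.gather_csr(src, indptr, None)
except RuntimeError as e:
    print(e)  # the "Not yet implemented" assertion surfaces as a RuntimeError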
#include <torch/extension.h>

#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(indptr);
  if (out_opt.has_value())
    CHECK_CPU(out_opt.value());

  AT_ASSERTM(false, "Not yet implemented");
  return std::make_tuple(src, at::nullopt);
}

std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
            std::string reduce) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_CPU(out);

  AT_ASSERTM(false, "Not yet implemented");
  return std::make_tuple(src, at::nullopt);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("segment_csr", &segment_csr, "Segment CSR (CPU)");
  m.def("segment_coo", &segment_coo, "Segment COO (CPU)");
}
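For intuition about what this stub will eventually compute: segment_csr reduces each contiguous slice src[indptr[i]:indptr[i+1]] into output position i. A pure-PyTorch reference for the 1-D 'add' case (a sketch of the operator's semantics, not the planned kernel):

import torch

def segment_csr_add_reference(src, indptr):
    # Segment i covers src[indptr[i]:indptr[i + 1]]; reduce each with 'add'.
    out = src.new_zeros(indptr.numel() - 1)
    for i in range(indptr.numel() - 1):
        out[i] = src[indptr[i]:indptr[i + 1]].sum()
    return out

src = torch.tensor([1., 2., 3., 4., 5., 6.])
indptr = torch.tensor([0, 2, 5, 6])
print(segment_csr_add_reference(src, indptr))  # tensor([ 3., 12.,  6.])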
@@ -25,8 +25,9 @@ cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
 ext_modules = []
 exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
 ext_modules += [
-    CppExtension(f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
-                 extra_compile_args=cxx_extra_compile_args) for ext in exts
+    CppExtension(
+        f'torch_scatter.{ext}_cpu', [f'cpu/{ext}.cpp'],
+        extra_compile_args=cxx_extra_compile_args) for ext in exts
 ]

 if CUDA_HOME is not None and USE_GPU:
@@ -34,7 +35,8 @@ if CUDA_HOME is not None and USE_GPU:
     ext_modules += [
         CUDAExtension(
             f'torch_scatter.{ext}_cuda',
-            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'], extra_compile_args={
+            [f'cuda/{ext}.cpp', f'cuda/{ext}_kernel.cu'],
+            extra_compile_args={
                 'cxx': cxx_extra_compile_args,
                 'nvcc': nvcc_extra_compile_args,
             }) for ext in exts
...
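The glob-based discovery means any new cpu/*.cpp file automatically becomes a torch_scatter.<name>_cpu extension module. Tracing the comprehension by hand, assuming the two CPU sources added in this commit are cpu/gather.cpp and cpu/segment.cpp (consistent with the gather_cpu/segment_cpu imports below):

import os.path as osp

files = [osp.join('cpu', 'gather.cpp'), osp.join('cpu', 'segment.cpp')]

exts = [e.split(osp.sep)[-1][:-4] for e in files]  # strip directory and '.cpp'
print(exts)  # ['gather', 'segment']

print([f'torch_scatter.{ext}_cpu' for ext in exts])
# ['torch_scatter.gather_cpu', 'torch_scatter.segment_cpu']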
 import torch
+from torch_scatter import segment_cpu, gather_cpu

 if torch.cuda.is_available():
     from torch_scatter import gather_cuda, segment_cuda

+gat = lambda is_cuda: gather_cuda if is_cuda else gather_cpu  # noqa
+seg = lambda is_cuda: segment_cuda if is_cuda else segment_cpu  # noqa


 class GatherCOO(torch.autograd.Function):
     @staticmethod
@@ -12,7 +17,7 @@ class GatherCOO(torch.autograd.Function):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(index)
-        return gather_cuda.gather_coo(src, index, out)
+        return gat(src.is_cuda).gather_coo(src, index, out)

     @staticmethod
     def backward(ctx, grad_out):
@@ -20,7 +25,7 @@ class GatherCOO(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = segment_cuda.segment_coo(
+            grad_src, _ = seg(grad_out.is_cuda).segment_coo(
                 grad_out, index, grad_out.new_zeros(src_size), 'add')

         return grad_src, None, None
@@ -34,7 +39,7 @@ class GatherCSR(torch.autograd.Function):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(indptr)
-        return gather_cuda.gather_csr(src, indptr, out)
+        return gat(src.is_cuda).gather_csr(src, indptr, out)

     @staticmethod
     def backward(ctx, grad_out):
@@ -42,7 +47,7 @@ class GatherCSR(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = segment_cuda.segment_csr(
+            grad_src, _ = seg(grad_out.is_cuda).segment_csr(
                 grad_out, indptr, grad_out.new_empty(src_size), 'add')

         return grad_src, None, None
...
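The gat/seg lambdas give every autograd Function a single dispatch point keyed on tensor placement, so forward and backward automatically run on whichever backend holds the data. A minimal standalone sketch of the pattern, with stand-in namespaces in place of the compiled extension modules:

import types

import torch

# Stand-ins for the compiled gather_cpu / gather_cuda modules (illustration only).
gather_cpu = types.SimpleNamespace(gather_coo=lambda src, index, out: 'cpu kernel')
gather_cuda = types.SimpleNamespace(gather_coo=lambda src, index, out: 'cuda kernel')

gat = lambda is_cuda: gather_cuda if is_cuda else gather_cpu  # noqa

src = torch.randn(4)
print(gat(src.is_cuda).gather_coo(src, None, None))  # 'cpu kernel' for a CPU tensor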
 import torch
+from torch_scatter import segment_cpu, gather_cpu
 from torch_scatter.helpers import min_value, max_value

 if torch.cuda.is_available():
     from torch_scatter import segment_cuda, gather_cuda

+seg = lambda is_cuda: segment_cuda if is_cuda else segment_cpu  # noqa
+gat = lambda is_cuda: gather_cuda if is_cuda else gather_cpu  # noqa


 class SegmentCOO(torch.autograd.Function):
     @staticmethod
@@ -28,7 +32,7 @@ class SegmentCOO(torch.autograd.Function):
             out = src.new_full(size, fill_value)

-        out, arg_out = segment_cuda.segment_coo(src, index, out, reduce)
+        out, arg_out = seg(src.is_cuda).segment_coo(src, index, out, reduce)
         if fill_value != 0:
             out.masked_fill_(out == fill_value, 0)
@@ -47,13 +51,13 @@ class SegmentCOO(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gather_cuda.gather_coo(grad_out, index,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_coo(
+                    grad_out, index, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gather_cuda.gather_coo(grad_out, index,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_coo(
+                    grad_out, index, grad_out.new_empty(src_size))
                 count = arg_out
-                count = gather_cuda.gather_coo(
+                count = gat(grad_out.is_cuda).gather_coo(
                     count, index, count.new_empty(src_size[:index.dim()]))
                 for _ in range(grad_out.dim() - index.dim()):
                     count = count.unsqueeze(-1)
@@ -78,7 +82,7 @@ class SegmentCSR(torch.autograd.Function):
         ctx.reduce = reduce
         ctx.src_size = list(src.size())

-        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+        out, arg_out = seg(src.is_cuda).segment_csr(src, indptr, out, reduce)
         ctx.save_for_backward(indptr, arg_out)
         return out if arg_out is None else (out, arg_out)
@@ -89,15 +93,15 @@ class SegmentCSR(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gather_cuda.gather_csr(grad_out, indptr,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gather_cuda.gather_csr(grad_out, indptr,
-                                                  grad_out.new_empty(src_size))
+                grad_src = gat(grad_out.is_cuda).gather_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size))
                 indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
                 indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
                 count = (indptr2 - indptr1).to(grad_src.dtype)
-                count = gather_cuda.gather_csr(
+                count = gat(grad_out.is_cuda).gather_csr(
                     count, indptr, count.new_empty(src_size[:indptr.dim()]))
                 for _ in range(grad_out.dim() - indptr.dim()):
                     count = count.unsqueeze(-1)
...
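The 'mean' branch implements the chain rule for a segment mean: if out_i is the average of the n_i elements in segment i, then d out_i / d src_j = 1 / n_i for every j in that segment, which is why the backward gathers grad_out into src's shape and divides by the gathered count. A quick autograd check of that identity with plain PyTorch indexing (not the extension):

import torch

src = torch.tensor([1., 2., 3., 4., 5., 6.], requires_grad=True)
indptr = torch.tensor([0, 2, 5, 6])

# Segment means via plain slicing: [mean(1,2), mean(3,4,5), mean(6)].
bounds = zip(indptr[:-1].tolist(), indptr[1:].tolist())
out = torch.stack([src[a:b].mean() for a, b in bounds])
out.sum().backward()

# Each gradient entry is 1 / (length of its segment).
print(src.grad)  # tensor([0.5000, 0.5000, 0.3333, 0.3333, 0.3333, 1.0000])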