Commit 99db5b80 authored by rusty1s's avatar rusty1s
Browse files

benchmark fixes

parent 82838e1d
...@@ -7,9 +7,7 @@ import wget ...@@ -7,9 +7,7 @@ import wget
import torch import torch
from scipy.io import loadmat from scipy.io import loadmat
import torch_scatter from torch_scatter import scatter, segment_coo, segment_csr
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
short_rows = [ short_rows = [
('DIMACS10', 'citationCiteseer'), ('DIMACS10', 'citationCiteseer'),
...@@ -47,34 +45,30 @@ def correctness(dataset): ...@@ -47,34 +45,30 @@ def correctness(dataset):
x = torch.randn((row.size(0), size), device=args.device) x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x x = x.squeeze(-1) if size == 1 else x
out1 = scatter_add(x, row, dim=0, dim_size=dim_size) out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add') out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
out3 = segment_csr(x, rowptr, reduce='add') out3 = segment_csr(x, rowptr, reduce='add')
assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4)
out1 = scatter_mean(x, row, dim=0, dim_size=dim_size) out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean') out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
out3 = segment_csr(x, rowptr, reduce='mean') out3 = segment_csr(x, rowptr, reduce='mean')
assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_().mul_(-1) out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
out2 = segment_coo(x, row, reduce='min')
out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1)) out3 = segment_csr(x, rowptr, reduce='min')
out2, _ = segment_coo(x, row, reduce='min')
out3, _ = segment_csr(x, rowptr, reduce='min')
assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_() out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
out2 = segment_coo(x, row, reduce='max')
out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1)) out3 = segment_csr(x, rowptr, reduce='max')
out2, _ = segment_coo(x, row, reduce='max')
out3, _ = segment_csr(x, rowptr, reduce='max')
assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4)
...@@ -117,17 +111,15 @@ def timing(dataset): ...@@ -117,17 +111,15 @@ def timing(dataset):
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long) rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long) row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
row_perm = row[torch.randperm(row.size(0))] row2 = row[torch.randperm(row.size(0))]
dim_size = rowptr.size(0) - 1 dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size avg_row_len = row.size(0) / dim_size
def sca_row(x): def sca_row(x):
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}') return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
return op(x, row, dim=0, dim_size=dim_size)
def sca_col(x): def sca_col(x):
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}') return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
return op(x, row_perm, dim=0, dim_size=dim_size)
def seg_coo(x): def seg_coo(x):
return segment_coo(x, row, reduce=args.reduce) return segment_coo(x, row, reduce=args.reduce)
...@@ -205,11 +197,10 @@ def timing(dataset): ...@@ -205,11 +197,10 @@ def timing(dataset):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True, parser.add_argument('--reduce', type=str, required=True,
choices=['sum', 'mean', 'min', 'max']) choices=['sum', 'add', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true') parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args() args = parser.parse_args()
args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
iters = 1 if args.device == 'cpu' else 20 iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512] sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes sizes = sizes[:3] if args.device == 'cpu' else sizes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment