Commit 99db5b80 authored by rusty1s's avatar rusty1s
Browse files

benchmark fixes

parent 82838e1d
......@@ -7,9 +7,7 @@ import wget
import torch
from scipy.io import loadmat
import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
from torch_scatter import scatter, segment_coo, segment_csr
short_rows = [
('DIMACS10', 'citationCiteseer'),
......@@ -47,34 +45,30 @@ def correctness(dataset):
x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
out3 = segment_csr(x, rowptr, reduce='add')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
out3 = segment_csr(x, rowptr, reduce='mean')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_().mul_(-1)
out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
out2, _ = segment_coo(x, row, reduce='min')
out3, _ = segment_csr(x, rowptr, reduce='min')
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
out2 = segment_coo(x, row, reduce='min')
out3 = segment_csr(x, rowptr, reduce='min')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
x = x.abs_()
out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
out2, _ = segment_coo(x, row, reduce='max')
out3, _ = segment_csr(x, rowptr, reduce='max')
out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
out2 = segment_coo(x, row, reduce='max')
out3 = segment_csr(x, rowptr, reduce='max')
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
......@@ -117,17 +111,15 @@ def timing(dataset):
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
row_perm = row[torch.randperm(row.size(0))]
row2 = row[torch.randperm(row.size(0))]
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
def sca_row(x):
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row, dim=0, dim_size=dim_size)
return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
def sca_col(x):
op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
return op(x, row_perm, dim=0, dim_size=dim_size)
return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
def seg_coo(x):
return segment_coo(x, row, reduce=args.reduce)
......@@ -205,11 +197,10 @@ def timing(dataset):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
choices=['sum', 'mean', 'min', 'max'])
choices=['sum', 'add', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
iters = 1 if args.device == 'cpu' else 20
sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment