Commit aae9e125 authored by rusty1s's avatar rusty1s
Browse files

update benchmark

parent d7f9176e
# flake8: noqa
import time
import itertools
import argparse
import torch
from scipy.io import loadmat
from torch_scatter import gather_coo, gather_csr
from scatter_segment import iters, device, sizes
from scatter_segment import iters, sizes
from scatter_segment import short_rows, long_rows, download, bold
......@@ -14,13 +17,13 @@ from scatter_segment import short_rows, long_rows, download, bold
def correctness(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
for size in sizes[1:]:
try:
x = torch.randn((dim_size, size), device=device)
x = torch.randn((dim_size, size), device=args.device)
x = x.squeeze(-1) if size == 1 else x
out1 = x.index_select(0, row)
......@@ -33,76 +36,49 @@ def correctness(dataset):
torch.cuda.empty_cache()
@torch.no_grad()
def time_func(func, x):
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
func(x)
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError:
torch.cuda.empty_cache()
return float('inf')
@torch.no_grad()
def timing(dataset):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
select = lambda x: x.index_select(0, row)
gather = lambda x: x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))
gat_coo = lambda x: gather_coo(x, row)
gat_csr = lambda x: gather_csr(x, rowptr)
t1, t2, t3, t4 = [], [], [], []
for size in sizes:
try:
x = torch.randn((dim_size, size), device=device)
row_expand = row.view(-1, 1).expand(-1, x.size(-1))
x = x.squeeze(-1) if size == 1 else x
row_expand = row_expand.squeeze(-1) if size == 1 else row_expand
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = x.index_select(0, row)
del out
torch.cuda.synchronize()
t1.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t1.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = x.gather(0, row_expand)
del out
torch.cuda.synchronize()
t2.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t2.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = gather_coo(x, row)
del out
torch.cuda.synchronize()
t3.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t3.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = gather_csr(x, rowptr)
del out
torch.cuda.synchronize()
t4.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t4.append(float('inf'))
x = torch.randn((dim_size, size), device=args.device)
t1 += [time_func(select, x)]
t2 += [time_func(gather, x)]
t3 += [time_func(gat_coo, x)]
t4 += [time_func(gat_csr, x)]
del x
except RuntimeError:
torch.cuda.empty_cache()
for t in (t1, t2, t3):
for t in (t1, t2, t3, t4):
t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4])
......@@ -125,8 +101,12 @@ def timing(dataset):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
for _ in range(10): # Warmup.
torch.randn(100, 100, device=device).sum()
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
......
......@@ -3,8 +3,8 @@
import time
import os.path as osp
import itertools
import argparse
import argparse
import wget
import torch
from scipy.io import loadmat
......@@ -13,12 +13,6 @@ import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True)
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
iters = 20
sizes = [1, 16, 32, 64, 128, 256, 512]
......@@ -94,6 +88,7 @@ def correctness(dataset):
torch.cuda.empty_cache()
@torch.no_grad()
def time_func(func, x):
try:
torch.cuda.synchronize()
......@@ -184,6 +179,13 @@ def timing(dataset):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
choices=['add', 'mean', 'min', 'max'])
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment