Commit aae9e125 authored by rusty1s's avatar rusty1s
Browse files

update benchmark

parent d7f9176e
# flake8: noqa
import time import time
import itertools import itertools
import argparse
import torch import torch
from scipy.io import loadmat from scipy.io import loadmat
from torch_scatter import gather_coo, gather_csr from torch_scatter import gather_coo, gather_csr
from scatter_segment import iters, device, sizes from scatter_segment import iters, sizes
from scatter_segment import short_rows, long_rows, download, bold from scatter_segment import short_rows, long_rows, download, bold
...@@ -14,13 +17,13 @@ from scatter_segment import short_rows, long_rows, download, bold ...@@ -14,13 +17,13 @@ from scatter_segment import short_rows, long_rows, download, bold
def correctness(dataset): def correctness(dataset):
group, name = dataset group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long) rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long) row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1 dim_size = rowptr.size(0) - 1
for size in sizes[1:]: for size in sizes[1:]:
try: try:
x = torch.randn((dim_size, size), device=device) x = torch.randn((dim_size, size), device=args.device)
x = x.squeeze(-1) if size == 1 else x x = x.squeeze(-1) if size == 1 else x
out1 = x.index_select(0, row) out1 = x.index_select(0, row)
...@@ -34,75 +37,48 @@ def correctness(dataset): ...@@ -34,75 +37,48 @@ def correctness(dataset):
@torch.no_grad() @torch.no_grad()
def timing(dataset): def time_func(func, x):
group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long)
dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size
t1, t2, t3, t4 = [], [], [], []
for size in sizes:
try:
x = torch.randn((dim_size, size), device=device)
row_expand = row.view(-1, 1).expand(-1, x.size(-1))
x = x.squeeze(-1) if size == 1 else x
row_expand = row_expand.squeeze(-1) if size == 1 else row_expand
try: try:
torch.cuda.synchronize() torch.cuda.synchronize()
t = time.perf_counter() t = time.perf_counter()
for _ in range(iters): for _ in range(iters):
out = x.index_select(0, row) func(x)
del out
torch.cuda.synchronize() torch.cuda.synchronize()
t1.append(time.perf_counter() - t) return time.perf_counter() - t
except RuntimeError: except RuntimeError:
torch.cuda.empty_cache() torch.cuda.empty_cache()
t1.append(float('inf')) return float('inf')
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = x.gather(0, row_expand)
del out
torch.cuda.synchronize()
t2.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t2.append(float('inf'))
try: @torch.no_grad()
torch.cuda.synchronize() def timing(dataset):
t = time.perf_counter() group, name = dataset
for _ in range(iters): mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
out = gather_coo(x, row) rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
del out row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
torch.cuda.synchronize() dim_size = rowptr.size(0) - 1
t3.append(time.perf_counter() - t) avg_row_len = row.size(0) / dim_size
except RuntimeError:
torch.cuda.empty_cache() select = lambda x: x.index_select(0, row)
t3.append(float('inf')) gather = lambda x: x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))
gat_coo = lambda x: gather_coo(x, row)
gat_csr = lambda x: gather_csr(x, rowptr)
t1, t2, t3, t4 = [], [], [], []
for size in sizes:
try: try:
torch.cuda.synchronize() x = torch.randn((dim_size, size), device=args.device)
t = time.perf_counter()
for _ in range(iters): t1 += [time_func(select, x)]
out = gather_csr(x, rowptr) t2 += [time_func(gather, x)]
del out t3 += [time_func(gat_coo, x)]
torch.cuda.synchronize() t4 += [time_func(gat_csr, x)]
t4.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t4.append(float('inf'))
del x del x
except RuntimeError: except RuntimeError:
torch.cuda.empty_cache() torch.cuda.empty_cache()
for t in (t1, t2, t3): for t in (t1, t2, t3, t4):
t.append(float('inf')) t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4]) ts = torch.tensor([t1, t2, t3, t4])
...@@ -125,8 +101,12 @@ def timing(dataset): ...@@ -125,8 +101,12 @@ def timing(dataset):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
for _ in range(10): # Warmup. for _ in range(10): # Warmup.
torch.randn(100, 100, device=device).sum() torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows): for dataset in itertools.chain(short_rows, long_rows):
download(dataset) download(dataset)
correctness(dataset) correctness(dataset)
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import time import time
import os.path as osp import os.path as osp
import itertools import itertools
import argparse
import argparse
import wget import wget
import torch import torch
from scipy.io import loadmat from scipy.io import loadmat
...@@ -13,12 +13,6 @@ import torch_scatter ...@@ -13,12 +13,6 @@ import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr from torch_scatter import segment_coo, segment_csr
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True)
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
iters = 20 iters = 20
sizes = [1, 16, 32, 64, 128, 256, 512] sizes = [1, 16, 32, 64, 128, 256, 512]
...@@ -94,6 +88,7 @@ def correctness(dataset): ...@@ -94,6 +88,7 @@ def correctness(dataset):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@torch.no_grad()
def time_func(func, x): def time_func(func, x):
try: try:
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -184,6 +179,13 @@ def timing(dataset): ...@@ -184,6 +179,13 @@ def timing(dataset):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
choices=['add', 'mean', 'min', 'max'])
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
for _ in range(10): # Warmup. for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum() torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows): for dataset in itertools.chain(short_rows, long_rows):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment