Commit aae9e125 authored by rusty1s's avatar rusty1s
Browse files

update benchmark

parent d7f9176e
# flake8: noqa
import time import time
import itertools import itertools
import argparse
import torch import torch
from scipy.io import loadmat from scipy.io import loadmat
from torch_scatter import gather_coo, gather_csr from torch_scatter import gather_coo, gather_csr
from scatter_segment import iters, device, sizes from scatter_segment import iters, sizes
from scatter_segment import short_rows, long_rows, download, bold from scatter_segment import short_rows, long_rows, download, bold
...@@ -14,13 +17,13 @@ from scatter_segment import short_rows, long_rows, download, bold ...@@ -14,13 +17,13 @@ from scatter_segment import short_rows, long_rows, download, bold
def correctness(dataset): def correctness(dataset):
group, name = dataset group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long) rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long) row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1 dim_size = rowptr.size(0) - 1
for size in sizes[1:]: for size in sizes[1:]:
try: try:
x = torch.randn((dim_size, size), device=device) x = torch.randn((dim_size, size), device=args.device)
x = x.squeeze(-1) if size == 1 else x x = x.squeeze(-1) if size == 1 else x
out1 = x.index_select(0, row) out1 = x.index_select(0, row)
...@@ -33,76 +36,49 @@ def correctness(dataset): ...@@ -33,76 +36,49 @@ def correctness(dataset):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@torch.no_grad()
def time_func(func, x):
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
func(x)
torch.cuda.synchronize()
return time.perf_counter() - t
except RuntimeError:
torch.cuda.empty_cache()
return float('inf')
@torch.no_grad() @torch.no_grad()
def timing(dataset): def timing(dataset):
group, name = dataset group, name = dataset
mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
rowptr = torch.from_numpy(mat.indptr).to(device, torch.long) rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
row = torch.from_numpy(mat.tocoo().row).to(device, torch.long) row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
dim_size = rowptr.size(0) - 1 dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size avg_row_len = row.size(0) / dim_size
select = lambda x: x.index_select(0, row)
gather = lambda x: x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))
gat_coo = lambda x: gather_coo(x, row)
gat_csr = lambda x: gather_csr(x, rowptr)
t1, t2, t3, t4 = [], [], [], [] t1, t2, t3, t4 = [], [], [], []
for size in sizes: for size in sizes:
try: try:
x = torch.randn((dim_size, size), device=device) x = torch.randn((dim_size, size), device=args.device)
row_expand = row.view(-1, 1).expand(-1, x.size(-1))
x = x.squeeze(-1) if size == 1 else x t1 += [time_func(select, x)]
row_expand = row_expand.squeeze(-1) if size == 1 else row_expand t2 += [time_func(gather, x)]
t3 += [time_func(gat_coo, x)]
try: t4 += [time_func(gat_csr, x)]
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = x.index_select(0, row)
del out
torch.cuda.synchronize()
t1.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t1.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = x.gather(0, row_expand)
del out
torch.cuda.synchronize()
t2.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t2.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = gather_coo(x, row)
del out
torch.cuda.synchronize()
t3.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t3.append(float('inf'))
try:
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(iters):
out = gather_csr(x, rowptr)
del out
torch.cuda.synchronize()
t4.append(time.perf_counter() - t)
except RuntimeError:
torch.cuda.empty_cache()
t4.append(float('inf'))
del x del x
except RuntimeError: except RuntimeError:
torch.cuda.empty_cache() torch.cuda.empty_cache()
for t in (t1, t2, t3): for t in (t1, t2, t3, t4):
t.append(float('inf')) t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4]) ts = torch.tensor([t1, t2, t3, t4])
...@@ -125,8 +101,12 @@ def timing(dataset): ...@@ -125,8 +101,12 @@ def timing(dataset):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
for _ in range(10): # Warmup. for _ in range(10): # Warmup.
torch.randn(100, 100, device=device).sum() torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows): for dataset in itertools.chain(short_rows, long_rows):
download(dataset) download(dataset)
correctness(dataset) correctness(dataset)
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import time import time
import os.path as osp import os.path as osp
import itertools import itertools
import argparse
import argparse
import wget import wget
import torch import torch
from scipy.io import loadmat from scipy.io import loadmat
...@@ -13,12 +13,6 @@ import torch_scatter ...@@ -13,12 +13,6 @@ import torch_scatter
from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
from torch_scatter import segment_coo, segment_csr from torch_scatter import segment_coo, segment_csr
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True)
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
iters = 20 iters = 20
sizes = [1, 16, 32, 64, 128, 256, 512] sizes = [1, 16, 32, 64, 128, 256, 512]
...@@ -94,6 +88,7 @@ def correctness(dataset): ...@@ -94,6 +88,7 @@ def correctness(dataset):
torch.cuda.empty_cache() torch.cuda.empty_cache()
@torch.no_grad()
def time_func(func, x): def time_func(func, x):
try: try:
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -184,6 +179,13 @@ def timing(dataset): ...@@ -184,6 +179,13 @@ def timing(dataset):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True,
choices=['add', 'mean', 'min', 'max'])
parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args()
args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
for _ in range(10): # Warmup. for _ in range(10): # Warmup.
torch.randn(100, 100, device=args.device).sum() torch.randn(100, 100, device=args.device).sum()
for dataset in itertools.chain(short_rows, long_rows): for dataset in itertools.chain(short_rows, long_rows):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment