Commit 1006514c authored by rusty1s
Browse files

test with pytorch scatter

parent f056396b
...@@ -115,10 +115,20 @@ def timing(dataset): ...@@ -115,10 +115,20 @@ def timing(dataset):
dim_size = rowptr.size(0) - 1 dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size avg_row_len = row.size(0) / dim_size
def sca_row(x): def sca1_row(x):
out = x.new_zeros(dim_size, *x.size()[1:])
row_tmp = row.view(-1, 1).expand_as(x) if x.dim() > 1 else row
return out.scatter_add_(0, row_tmp, x)
def sca1_col(x):
out = x.new_zeros(dim_size, *x.size()[1:])
row2_tmp = row2.view(-1, 1).expand_as(x) if x.dim() > 1 else row2
return out.scatter_add_(0, row2_tmp, x)
def sca2_row(x):
return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce) return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
def sca_col(x): def sca2_col(x):
return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce) return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
def seg_coo(x): def seg_coo(x):
...@@ -133,17 +143,19 @@ def timing(dataset): ...@@ -133,17 +143,19 @@ def timing(dataset):
def dense2(x): def dense2(x):
return getattr(torch, args.reduce)(x, dim=-1) return getattr(torch, args.reduce)(x, dim=-1)
t1, t2, t3, t4, t5, t6 = [], [], [], [], [], [] t1, t2, t3, t4, t5, t6, t7, t8 = [], [], [], [], [], [], [], []
for size in sizes: for size in sizes:
try: try:
x = torch.randn((row.size(0), size), device=args.device) x = torch.randn((row.size(0), size), device=args.device)
x = x.squeeze(-1) if size == 1 else x x = x.squeeze(-1) if size == 1 else x
t1 += [time_func(sca_row, x)] t1 += [time_func(sca1_row, x)]
t2 += [time_func(sca_col, x)] t2 += [time_func(sca1_col, x)]
t3 += [time_func(seg_coo, x)] t3 += [time_func(sca2_row, x)]
t4 += [time_func(seg_csr, x)] t4 += [time_func(sca2_col, x)]
t5 += [time_func(seg_coo, x)]
t6 += [time_func(seg_csr, x)]
del x del x
...@@ -151,16 +163,16 @@ def timing(dataset): ...@@ -151,16 +163,16 @@ def timing(dataset):
if 'out of memory' not in str(e): if 'out of memory' not in str(e):
raise RuntimeError(e) raise RuntimeError(e)
torch.cuda.empty_cache() torch.cuda.empty_cache()
for t in (t1, t2, t3, t4): for t in (t1, t2, t3, t4, t5, t6):
t.append(float('inf')) t.append(float('inf'))
try: try:
x = torch.randn((dim_size, int(avg_row_len + 1), size), x = torch.randn((dim_size, int(avg_row_len + 1), size),
device=args.device) device=args.device)
t5 += [time_func(dense1, x)] t7 += [time_func(dense1, x)]
x = x.view(dim_size, size, int(avg_row_len + 1)) x = x.view(dim_size, size, int(avg_row_len + 1))
t6 += [time_func(dense2, x)] t8 += [time_func(dense2, x)]
del x del x
...@@ -168,10 +180,10 @@ def timing(dataset): ...@@ -168,10 +180,10 @@ def timing(dataset):
if 'out of memory' not in str(e): if 'out of memory' not in str(e):
raise RuntimeError(e) raise RuntimeError(e)
torch.cuda.empty_cache() torch.cuda.empty_cache()
for t in (t5, t6): for t in (t7, t8):
t.append(float('inf')) t.append(float('inf'))
ts = torch.tensor([t1, t2, t3, t4, t5, t6]) ts = torch.tensor([t1, t2, t3, t4, t5, t6, t7, t8])
winner = torch.zeros_like(ts, dtype=torch.bool) winner = torch.zeros_like(ts, dtype=torch.bool)
winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1 winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
winner = winner.tolist() winner = winner.tolist()
...@@ -179,29 +191,33 @@ def timing(dataset): ...@@ -179,29 +191,33 @@ def timing(dataset):
name = f'{group}/{name}' name = f'{group}/{name}'
print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):') print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
print('\t'.join([' '] + [f'{size:>5}' for size in sizes])) print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
print('\t'.join([bold('SCA_ROW')] + print('\t'.join([bold('SCA1_R ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])])) [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
print('\t'.join([bold('SCA_COL')] + print('\t'.join([bold('SCA1_C ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])])) [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
print('\t'.join([bold('SEG_COO')] + print('\t'.join([bold('SCA2_R ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])])) [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
print('\t'.join([bold('SEG_CSR')] + print('\t'.join([bold('SCA2_C ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])])) [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
print('\t'.join([bold('DENSE1 ')] + print('\t'.join([bold('SEG_COO')] +
[bold(f'{t:.5f}', f) for t, f in zip(t5, winner[4])])) [bold(f'{t:.5f}', f) for t, f in zip(t5, winner[4])]))
print('\t'.join([bold('DENSE2 ')] + print('\t'.join([bold('SEG_CSR')] +
[bold(f'{t:.5f}', f) for t, f in zip(t6, winner[5])])) [bold(f'{t:.5f}', f) for t, f in zip(t6, winner[5])]))
print('\t'.join([bold('DENSE1 ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t7, winner[6])]))
print('\t'.join([bold('DENSE2 ')] +
[bold(f'{t:.5f}', f) for t, f in zip(t8, winner[7])]))
print() print()
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--reduce', type=str, required=True, parser.add_argument('--reduce', type=str, required=True,
choices=['sum', 'add', 'mean', 'min', 'max']) choices=['sum', 'mean', 'min', 'max'])
parser.add_argument('--with_backward', action='store_true') parser.add_argument('--with_backward', action='store_true')
parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--device', type=str, default='cuda')
args = parser.parse_args() args = parser.parse_args()
iters = 1 if args.device == 'cpu' else 20 iters = 1 if args.device == 'cpu' else 50
sizes = [1, 16, 32, 64, 128, 256, 512] sizes = [1, 16, 32, 64, 128, 256, 512]
sizes = sizes[:3] if args.device == 'cpu' else sizes sizes = sizes[:3] if args.device == 'cpu' else sizes
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment