Commit 1006514c authored by rusty1s

test with pytorch scatter

parent f056396b
@@ -115,10 +115,20 @@ def timing(dataset):
     dim_size = rowptr.size(0) - 1
     avg_row_len = row.size(0) / dim_size
 
-    def sca_row(x):
+    def sca1_row(x):
         out = x.new_zeros(dim_size, *x.size()[1:])
         row_tmp = row.view(-1, 1).expand_as(x) if x.dim() > 1 else row
         return out.scatter_add_(0, row_tmp, x)
 
+    def sca1_col(x):
+        out = x.new_zeros(dim_size, *x.size()[1:])
+        row2_tmp = row2.view(-1, 1).expand_as(x) if x.dim() > 1 else row2
+        return out.scatter_add_(0, row2_tmp, x)
+
+    def sca2_row(x):
+        return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)
+
-    def sca_col(x):
+    def sca2_col(x):
         return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)
 
     def seg_coo(x):
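The renamed sca1_* closures keep the plain torch.Tensor.scatter_add_ baseline, while the new sca2_* closures route the same reduction through torch_scatter.scatter, which is the "pytorch scatter" of the commit message. A minimal standalone sketch of the equivalence being timed (toy tensors and reduce='sum' assumed; not part of the commit):

import torch
from torch_scatter import scatter

row = torch.tensor([0, 0, 1, 1, 2])  # sorted COO row indices, 5 edges
x = torch.randn(5, 8)                # one feature vector per edge
dim_size = 3                         # number of destination rows

# sca1-style: expand the index to the feature shape, scatter in place.
out1 = x.new_zeros(dim_size, x.size(1))
out1.scatter_add_(0, row.view(-1, 1).expand_as(x), x)

# sca2-style: torch_scatter broadcasts the index and applies the
# reduction itself.
out2 = scatter(x, row, dim=0, dim_size=dim_size, reduce='sum')

assert torch.allclose(out1, out2)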
@@ -133,17 +143,19 @@ def timing(dataset):
     def dense2(x):
         return getattr(torch, args.reduce)(x, dim=-1)
 
-    t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
+    t1, t2, t3, t4, t5, t6, t7, t8 = [], [], [], [], [], [], [], []
     for size in sizes:
         try:
             x = torch.randn((row.size(0), size), device=args.device)
             x = x.squeeze(-1) if size == 1 else x
 
-            t1 += [time_func(sca_row, x)]
-            t2 += [time_func(sca_col, x)]
-            t3 += [time_func(seg_coo, x)]
-            t4 += [time_func(seg_csr, x)]
+            t1 += [time_func(sca1_row, x)]
+            t2 += [time_func(sca1_col, x)]
+            t3 += [time_func(sca2_row, x)]
+            t4 += [time_func(sca2_col, x)]
+            t5 += [time_func(seg_coo, x)]
+            t6 += [time_func(seg_csr, x)]
 
             del x
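time_func itself is defined in a collapsed part of the file. A plausible shape, reconstructed only from how it is called here; the real version presumably reads the script's iters and args globals and also honors --with_backward:

import time
import torch

def time_func(func, x, iters=50, device='cuda'):
    # Plausible reconstruction, not the commit's actual code.
    if device == 'cuda':
        torch.cuda.synchronize()  # flush pending kernels before timing
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(iters):
            func(x)
    if device == 'cuda':
        torch.cuda.synchronize()  # wait for the timed kernels to finish
    return time.perf_counter() - start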
@@ -151,16 +163,16 @@ def timing(dataset):
             if 'out of memory' not in str(e):
                 raise RuntimeError(e)
             torch.cuda.empty_cache()
-            for t in (t1, t2, t3, t4):
+            for t in (t1, t2, t3, t4, t5, t6):
                 t.append(float('inf'))
 
         try:
             x = torch.randn((dim_size, int(avg_row_len + 1), size),
                             device=args.device)
-            t5 += [time_func(dense1, x)]
+            t7 += [time_func(dense1, x)]
             x = x.view(dim_size, size, int(avg_row_len + 1))
-            t6 += [time_func(dense2, x)]
+            t8 += [time_func(dense2, x)]
 
             del x
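Both try blocks share the same out-of-memory guard: a CUDA OOM becomes an inf entry, so the method simply loses for that size instead of aborting the whole sweep, and empty_cache() returns the cached blocks so the next, possibly smaller configuration starts clean. Factored out as a hypothetical helper (not in the script), the pattern is:

def timed_or_inf(func, x):
    # Hypothetical helper: one timing cell that degrades to inf on
    # CUDA OOM instead of killing the benchmark run.
    try:
        return time_func(func, x)
    except RuntimeError as e:
        if 'out of memory' not in str(e):
            raise
        torch.cuda.empty_cache()  # release cached blocks for later sizes
        return float('inf')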
@@ -168,10 +180,10 @@ def timing(dataset):
             if 'out of memory' not in str(e):
                 raise RuntimeError(e)
             torch.cuda.empty_cache()
-            for t in (t5, t6):
+            for t in (t7, t8):
                 t.append(float('inf'))
 
-    ts = torch.tensor([t1, t2, t3, t4, t5, t6])
+    ts = torch.tensor([t1, t2, t3, t4, t5, t6, t7, t8])
     winner = torch.zeros_like(ts, dtype=torch.bool)
     winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1
     winner = winner.tolist()
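ts now holds eight rows (one per method) and one column per feature size; argmin over dim=0 picks the fastest method in each column, and the advanced-indexing assignment marks exactly those cells for bold printing. The same trick on toy data:

import torch

ts = torch.tensor([[3.0, 2.0],   # method 0 timings for two sizes
                   [1.0, 5.0]])  # method 1 timings
winner = torch.zeros_like(ts, dtype=torch.bool)
winner[ts.argmin(dim=0), torch.arange(ts.size(1))] = 1
# winner: [[False, True], [True, False]]
# method 1 wins the first size, method 0 the second.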
@@ -179,29 +191,33 @@ def timing(dataset):
     name = f'{group}/{name}'
     print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):')
     print('\t'.join([' '] + [f'{size:>5}' for size in sizes]))
-    print('\t'.join([bold('SCA_ROW')] +
+    print('\t'.join([bold('SCA1_R ')] +
                     [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])]))
-    print('\t'.join([bold('SCA_COL')] +
+    print('\t'.join([bold('SCA1_C ')] +
                     [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])]))
-    print('\t'.join([bold('SEG_COO')] +
+    print('\t'.join([bold('SCA2_R ')] +
                     [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])]))
-    print('\t'.join([bold('SEG_CSR')] +
+    print('\t'.join([bold('SCA2_C ')] +
                     [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])]))
-    print('\t'.join([bold('DENSE1 ')] +
+    print('\t'.join([bold('SEG_COO')] +
                     [bold(f'{t:.5f}', f) for t, f in zip(t5, winner[4])]))
-    print('\t'.join([bold('DENSE2 ')] +
+    print('\t'.join([bold('SEG_CSR')] +
                     [bold(f'{t:.5f}', f) for t, f in zip(t6, winner[5])]))
+    print('\t'.join([bold('DENSE1 ')] +
+                    [bold(f'{t:.5f}', f) for t, f in zip(t7, winner[6])]))
+    print('\t'.join([bold('DENSE2 ')] +
+                    [bold(f'{t:.5f}', f) for t, f in zip(t8, winner[7])]))
     print()
 
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--reduce', type=str, required=True,
-                        choices=['sum', 'add', 'mean', 'min', 'max'])
+                        choices=['sum', 'mean', 'min', 'max'])
     parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
 
-    iters = 1 if args.device == 'cpu' else 20
+    iters = 1 if args.device == 'cpu' else 50
     sizes = [1, 16, 32, 64, 128, 256, 512]
     sizes = sizes[:3] if args.device == 'cpu' else sizes
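bold() is also defined outside the shown hunks; judging from the two call forms above, bold(name) and bold(f'{t:.5f}', f), a plausible ANSI implementation is:

def bold(text, flag=True):
    # Plausible reconstruction: wrap in ANSI bold codes when flag is set,
    # so the per-column winner stands out in the printed table.
    return f'\033[1m{text}\033[0m' if flag else text

After this commit a GPU run times 50 iterations per cell instead of 20, and 'add' (which older torch_scatter APIs treated as a synonym for 'sum') is gone from the --reduce choices, so an invocation must name sum, mean, min, or max explicitly.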