Commit 34045b9a authored by rusty1s's avatar rusty1s
Browse files

linting

parent 4e4b69bd
# flake8: noqa
import time import time
import itertools import itertools
...@@ -66,10 +64,17 @@ def timing(dataset): ...@@ -66,10 +64,17 @@ def timing(dataset):
dim_size = rowptr.size(0) - 1 dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size avg_row_len = row.size(0) / dim_size
select = lambda x: x.index_select(0, row) def select(x):
gather = lambda x: x.gather(0, row.view(-1, 1).expand(-1, x.size(1))) return x.index_select(0, row)
gat_coo = lambda x: gather_coo(x, row)
gat_csr = lambda x: gather_csr(x, rowptr) def gather(x):
return x.gather(0, row.view(-1, 1).expand(-1, x.size(1)))
def gat_coo(x):
return gather_coo(x, row)
def gat_csr(x):
return gather_csr(x, rowptr)
t1, t2, t3, t4 = [], [], [], [] t1, t2, t3, t4 = [], [], [], []
for size in sizes: for size in sizes:
......
# flake8: noqa
import time import time
import os.path as osp import os.path as osp
import itertools import itertools
...@@ -120,14 +118,25 @@ def timing(dataset): ...@@ -120,14 +118,25 @@ def timing(dataset):
dim_size = rowptr.size(0) - 1 dim_size = rowptr.size(0) - 1
avg_row_len = row.size(0) / dim_size avg_row_len = row.size(0) / dim_size
sca_row = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')( def sca_row(x):
x, row, dim=0, dim_size=dim_size) op = getattr(torch_scatter, f'scatter_{args.reduce}')
sca_col = lambda x: getattr(torch_scatter, f'scatter_{args.reduce}')( return op(x, row, dim=0, dim_size=dim_size)
x, row_perm, dim=0, dim_size=dim_size)
seg_coo = lambda x: segment_coo(x, row, reduce=args.reduce) def sca_col(x):
seg_csr = lambda x: segment_csr(x, rowptr, reduce=args.reduce) op = getattr(torch_scatter, f'scatter_{args.reduce}')
dense1 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-2) return op(x, row_perm, dim=0, dim_size=dim_size)
dense2 = lambda x: getattr(torch, args.dense_reduce)(x, dim=-1)
def seg_coo(x):
return segment_coo(x, row, reduce=args.reduce)
def seg_csr(x):
return segment_csr(x, rowptr, reduce=args.reduce)
def dense1(x):
return getattr(torch, args.dense_reduce)(x, dim=-2)
def dense2(x):
return getattr(torch, args.dense_reduce)(x, dim=-1)
t1, t2, t3, t4, t5, t6 = [], [], [], [], [], [] t1, t2, t3, t4, t5, t6 = [], [], [], [], [], []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment