"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "d50ed5219d00d4089621f1ca02de0e8986ce270b"
Commit 4e4b69bd authored by rusty1s's avatar rusty1s
Browse files

added benchmark backward option

parent aae9e125
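In brief: the two benchmark scripts gain a `--with_backward` flag so the timing loop can optionally include the backward pass, and the `SegmentCOO`/`SegmentCSR` autograd functions are adjusted so the `'mean'` backward broadcasts its per-segment counts against trailing feature dimensions and `SegmentCOO.backward` returns one gradient slot per `forward` argument.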
@@ -36,13 +36,21 @@ def correctness(dataset):
         torch.cuda.empty_cache()


-@torch.no_grad()
 def time_func(func, x):
     try:
         torch.cuda.synchronize()
         t = time.perf_counter()
-        for _ in range(iters):
-            func(x)
+
+        if not args.with_backward:
+            with torch.no_grad():
+                for _ in range(iters):
+                    func(x)
+        else:
+            x = x.requires_grad_()
+            for _ in range(iters):
+                out = func(x)
+                torch.autograd.grad(out, x, out, only_inputs=True)
+
         torch.cuda.synchronize()
         return time.perf_counter() - t
     except RuntimeError:
@@ -50,7 +58,6 @@ def time_func(func, x):
         return float('inf')


-@torch.no_grad()
 def timing(dataset):
     group, name = dataset
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
@@ -102,6 +109,7 @@ def timing(dataset):


 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
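The new timing path calls `torch.autograd.grad` directly instead of `out.backward()`: passing `out` itself as the gradient seed reuses an existing tensor of the right shape, and `only_inputs=True` limits the work to the gradient with respect to `x`. A self-contained sketch of the same pattern (the `iters` default and the example op are placeholders, not from the benchmark):

```python
import time

import torch


def time_func(func, x, iters=10, with_backward=False):
    # Mirrors the benchmark's time_func; assumes `x` lives on a CUDA device.
    torch.cuda.synchronize()
    t = time.perf_counter()
    if not with_backward:
        with torch.no_grad():  # forward only: no autograd graph is built
            for _ in range(iters):
                func(x)
    else:
        x = x.requires_grad_()  # make `x` a leaf so gradients reach it
        for _ in range(iters):
            out = func(x)
            # Seed the backward pass with `out` itself (any tensor of the
            # same shape works) and compute only d(out)/d(x).
            torch.autograd.grad(out, x, out, only_inputs=True)
    torch.cuda.synchronize()
    return time.perf_counter() - t


if __name__ == '__main__':
    x = torch.randn(10000, 64, device='cuda')
    print(time_func(lambda x: x * x, x, with_backward=True))
```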
@@ -68,18 +68,18 @@ def correctness(dataset):
         x = x.abs_().mul_(-1)
-        out1, arg_out1 = scatter_min(x, row, 0, torch.zeros_like(out1))
-        out2, arg_out2 = segment_coo(x, row, reduce='min')
-        out3, arg_out3 = segment_csr(x, rowptr, reduce='min')
+        out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
+        out2, _ = segment_coo(x, row, reduce='min')
+        out3, _ = segment_csr(x, rowptr, reduce='min')
         assert torch.allclose(out1, out2, atol=1e-4)
         assert torch.allclose(out1, out3, atol=1e-4)

         x = x.abs_()
-        out1, arg_out1 = scatter_max(x, row, 0, torch.zeros_like(out1))
-        out2, arg_out2 = segment_coo(x, row, reduce='max')
-        out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
+        out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
+        out2, _ = segment_coo(x, row, reduce='max')
+        out3, _ = segment_csr(x, rowptr, reduce='max')
         assert torch.allclose(out1, out2, atol=1e-4)
         assert torch.allclose(out1, out3, atol=1e-4)
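(For `'min'` and `'max'`, all three ops return a `(values, arg_indices)` pair; the correctness check only compares the values, so the unused `arg_out*` bindings become `_`.)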
@@ -88,13 +88,22 @@ def correctness(dataset):
         torch.cuda.empty_cache()


-@torch.no_grad()
 def time_func(func, x):
     try:
         torch.cuda.synchronize()
         t = time.perf_counter()
-        for _ in range(iters):
-            func(x)
+
+        if not args.with_backward:
+            with torch.no_grad():
+                for _ in range(iters):
+                    func(x)
+        else:
+            x = x.requires_grad_()
+            for _ in range(iters):
+                out = func(x)
+                out = out[0] if isinstance(out, tuple) else out
+                torch.autograd.grad(out, x, out, only_inputs=True)
+
         torch.cuda.synchronize()
         return time.perf_counter() - t
     except RuntimeError:
@@ -102,7 +111,6 @@ def time_func(func, x):
         return float('inf')


-@torch.no_grad()
 def timing(dataset):
     group, name = dataset
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
@@ -182,6 +190,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--reduce', type=str, required=True,
                         choices=['add', 'mean', 'min', 'max'])
+    parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
     args.dense_reduce = 'sum' if args.reduce == 'add' else args.reduce
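The second script's `time_func` differs in one line: for `'min'`/`'max'` the benchmarked ops return `(values, arg_indices)` tuples, and only the values tensor participates in autograd, so it is unwrapped before seeding the backward pass. The same pattern, shown with a stock PyTorch op:

```python
import torch

x = torch.randn(8, 4, requires_grad=True)
out = x.max(dim=0)  # a (values, indices) named tuple, like the segment ops
out = out[0] if isinstance(out, tuple) else out  # gradients flow through values
torch.autograd.grad(out, x, out, only_inputs=True)
```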
@@ -55,6 +55,8 @@ class SegmentCOO(torch.autograd.Function):
                 count = arg_out
             count = gather_cuda.gather_coo(
                 count, index, count.new_empty(src_size[:index.dim()]))
+            for _ in range(grad_out.dim() - index.dim()):
+                count = count.unsqueeze(-1)
             grad_src.div_(count)
         elif ctx.reduce == 'min' or ctx.reduce == 'max':
             src_size[index.dim() - 1] += 1
@@ -62,7 +64,7 @@ class SegmentCOO(torch.autograd.Function):
                 index.dim() - 1, arg_out, grad_out)
             grad_src = grad_src.narrow(index.dim() - 1, 0,
                                        src_size[index.dim() - 1] - 1)
-        return grad_src, None, None, None
+        return grad_src, None, None, None, None


 class SegmentCSR(torch.autograd.Function):
@@ -96,6 +98,8 @@ class SegmentCSR(torch.autograd.Function):
             count = (indptr2 - indptr1).to(grad_src.dtype)
             count = gather_cuda.gather_csr(
                 count, indptr, count.new_empty(src_size[:indptr.dim()]))
+            for _ in range(grad_out.dim() - indptr.dim()):
+                count = count.unsqueeze(-1)
             grad_src.div_(count)
         elif ctx.reduce == 'min' or ctx.reduce == 'max':
             src_size[indptr.dim() - 1] += 1
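Two library-side fixes round out the commit. In the `'mean'` backward of both `SegmentCOO` and `SegmentCSR`, `count` covers only the index dimensions, while `grad_src` may carry trailing feature dimensions; unsqueezing `count` once per missing dimension lets the in-place division broadcast. And since an `autograd.Function.backward` must return exactly one value per `forward` argument, `SegmentCOO`'s return grows by one `None`, presumably matching a newly added forward argument. A standalone illustration of the broadcasting fix (shapes invented for the example):

```python
import torch

grad_src = torch.randn(10, 3, 16)           # gradient with a trailing feature dim
count = torch.arange(1., 31.).view(10, 3)   # per-segment counts, index dims only

# Align trailing dimensions so the in-place division broadcasts:
# (10, 3) -> (10, 3, 1) against (10, 3, 16).
for _ in range(grad_src.dim() - count.dim()):
    count = count.unsqueeze(-1)
grad_src.div_(count)
```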