Commit 7e82bc0e authored by rusty1s's avatar rusty1s
Browse files

update benchmark script

parent f98ff7e8
...@@ -63,25 +63,25 @@ def correctness(dataset): ...@@ -63,25 +63,25 @@ def correctness(dataset):
assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4)
out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size) # out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
out3, arg_out3 = segment_csr(x, rowptr, reduce='max') # out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
# print(out1[:5]) # print(out1[:5])
# print(out3[:5]) # print(out3[:5])
nnz = (out1 != out3).nonzero().flatten() # nnz = (out1 != out3).nonzero().flatten()
nnz1 = nnz[0].item() # nnz1 = nnz[0].item()
print(rowptr[nnz1], rowptr[nnz1 + 1]) # print(rowptr[nnz1], rowptr[nnz1 + 1])
print(x[rowptr[nnz1]:rowptr[nnz1 + 1]]) # print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
print(x[rowptr[nnz1]:rowptr[nnz1 + 1]]) # print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
print(out1[nnz1]) # print(out1[nnz1])
print(out3[nnz1]) # print(out3[nnz1])
assert torch.allclose(out1, out3, atol=1e-4) # assert torch.allclose(out1, out3, atol=1e-4)
assert torch.all(arg_out1 == arg_out3) # assert torch.all(arg_out1 == arg_out3)
except RuntimeError: except RuntimeError:
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -225,4 +225,4 @@ if __name__ == '__main__': ...@@ -225,4 +225,4 @@ if __name__ == '__main__':
for dataset in itertools.chain(short_rows, long_rows): for dataset in itertools.chain(short_rows, long_rows):
download(dataset) download(dataset)
correctness(dataset) correctness(dataset)
# timing(dataset) timing(dataset)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment