Commit 7e82bc0e authored by rusty1s's avatar rusty1s
Browse files

update benchmark script

parent f98ff7e8
......@@ -63,25 +63,25 @@ def correctness(dataset):
assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
# out1, arg_out1 = scatter_max(x, row, dim=0, dim_size=dim_size)
# out3, arg_out3 = segment_csr(x, rowptr, reduce='max')
# print(out1[:5])
# print(out3[:5])
nnz = (out1 != out3).nonzero().flatten()
# nnz = (out1 != out3).nonzero().flatten()
nnz1 = nnz[0].item()
print(rowptr[nnz1], rowptr[nnz1 + 1])
# nnz1 = nnz[0].item()
# print(rowptr[nnz1], rowptr[nnz1 + 1])
print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
# print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
# print(x[rowptr[nnz1]:rowptr[nnz1 + 1]])
print(out1[nnz1])
print(out3[nnz1])
# print(out1[nnz1])
# print(out3[nnz1])
assert torch.allclose(out1, out3, atol=1e-4)
assert torch.all(arg_out1 == arg_out3)
# assert torch.allclose(out1, out3, atol=1e-4)
# assert torch.all(arg_out1 == arg_out3)
except RuntimeError:
torch.cuda.empty_cache()
......@@ -225,4 +225,4 @@ if __name__ == '__main__':
for dataset in itertools.chain(short_rows, long_rows):
download(dataset)
correctness(dataset)
# timing(dataset)
timing(dataset)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment