Commit de991ced authored by rusty1s's avatar rusty1s
Browse files

correctness check

parent 0934609b
......@@ -35,7 +35,27 @@ def bold(text, flag=True):
@torch.no_grad()
def correctness(dataset):
pass
group, name = dataset
mat_scipy = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
row = torch.from_numpy(mat_scipy.tocoo().row).to(args.device, torch.long)
col = torch.from_numpy(mat_scipy.tocoo().col).to(args.device, torch.long)
mat = SparseTensor(row=row, col=col, sparse_sizes=mat_scipy.shape)
mat.fill_cache_()
mat_pytorch = mat.to_torch_sparse_coo_tensor().coalesce()
for size in sizes:
try:
x = torch.randn((mat.size(1), size), device=args.device)
out1 = mat @ x
out2 = mat_pytorch @ x
assert torch.allclose(out1, out2, atol=1e-4)
except RuntimeError as e:
if 'out of memory' not in str(e):
raise RuntimeError(e)
torch.cuda.empty_cache()
def time_func(func, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment