Commit 7f8aac48 authored by rusty1s's avatar rusty1s
Browse files

final clean up

parent c6ca46d8
......@@ -7,21 +7,19 @@ from .utils import devices
@pytest.mark.parametrize('device', devices)
def test_metis(device):
weighted_mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
mat, partptr, perm = weighted_mat.partition(num_parts=2, recursive=False, sort_strategy=True)
assert partptr.numel() == 3
assert perm.numel() == 6
value1 = torch.randn(6 * 6, device=device).view(6, 6)
value2 = torch.arange(6 * 6, dtype=torch.long, device=device).view(6, 6)
value3 = torch.ones(6 * 6, device=device).view(6, 6)
mat, partptr, perm = weighted_mat.partition(num_parts=2, recursive=False, sort_strategy=False)
assert partptr.numel() == 3
assert perm.numel() == 6
for value in [value1, value2, value3]:
mat = SparseTensor.from_dense(value)
unweighted_mat = SparseTensor.from_dense(torch.ones((6, 6), device=device))
mat, partptr, perm = unweighted_mat.partition(num_parts=2, recursive=True, sort_strategy=True)
_, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=True)
assert partptr.numel() == 3
assert perm.numel() == 6
unweighted_mat = unweighted_mat.set_value(None)
mat, partptr, perm = unweighted_mat.partition(num_parts=2, recursive=True)
_, partptr, perm = mat.partition(num_parts=2, recursive=False,
weighted=False)
assert partptr.numel() == 3
assert perm.numel() == 6
from typing import Tuple
from typing import Tuple, Optional
import torch
from torch_sparse.tensor import SparseTensor
from torch_sparse.permute import permute
def cartesian1d(x, y):
a1, a2 = torch.meshgrid([x, y])
coos = torch.stack([a1, a2]).T.reshape(-1, 2)
return coos.split(1, dim=1)
def metis_weight1(x):
sorted_x = x.sort()[0]
diff = sorted_x[1:] - sorted_x[:-1]
def weight2metis(weight: torch.Tensor) -> Optional[torch.Tensor]:
sorted_weight = weight.sort()[0]
diff = sorted_weight[1:] - sorted_weight[:-1]
if diff.sum() == 0:
return None
xmin, xmax = sorted_x[[0, -1]]
srange = xmax - xmin
weight_min, weight_max = sorted_weight[0], sorted_weight[-1]
srange = weight_max - weight_min
min_diff = diff.min()
scale = (min_diff / srange).item()
tick, arange = scale.as_integer_ratio()
x_ratio = (x - xmin) / srange
return (x_ratio * arange + tick).long()
weight_ratio = (weight - weight_min).div_(srange).mul_(arange).add_(tick)
return weight_ratio.to(torch.long)
def metis_weight2(x):
t1, t2 = cartesian1d(x, x)
diff = t1 - t2
diff = diff[diff != 0]
if len(diff) == 0:
return None
xmin, xmax = x.min(), x.max()
srange = xmax - xmin
min_diff = diff.abs().min()
scale = (min_diff / srange).item()
tick, arange = scale.as_integer_ratio()
x_ratio = (x - xmin) / srange
return (x_ratio * arange + tick).long()
def metis_weight(x, sort_strategy=True):
return metis_weight1(x) if sort_strategy else metis_weight2(x)
def partition(src: SparseTensor, num_parts: int, recursive: bool = False, sort_strategy=True,
def partition(src: SparseTensor, num_parts: int, recursive: bool = False,
weighted=False
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr()
rowptr, col = rowptr.cpu(), col.cpu()
if value is not None and value.dim() == 1:
value = value.detach().cpu()
value = metis_weight(value, sort_strategy)
if value is not None and weighted:
assert value.numel() == col.numel()
value = value.view(-1).detach().cpu()
if value.is_floating_point():
value = weight2metis(value)
else:
value = None
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
recursive)
cluster = cluster.to(src.device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment