"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a51b6cc86a5bef62283562d49497a4f3e0b134d8"
Commit c6ca46d8 authored by bowendeng's avatar bowendeng
Browse files

add one optional arg to 'partition'

parent fb3459ad
...@@ -8,12 +8,16 @@ from .utils import devices ...@@ -8,12 +8,16 @@ from .utils import devices
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('device', devices)
def test_metis(device): def test_metis(device):
weighted_mat = SparseTensor.from_dense(torch.randn((6, 6), device=device)) weighted_mat = SparseTensor.from_dense(torch.randn((6, 6), device=device))
mat, partptr, perm = weighted_mat.partition(num_parts=2, recursive=False) mat, partptr, perm = weighted_mat.partition(num_parts=2, recursive=False, sort_strategy=True)
assert partptr.numel() == 3
assert perm.numel() == 6
mat, partptr, perm = weighted_mat.partition(num_parts=2, recursive=False, sort_strategy=False)
assert partptr.numel() == 3 assert partptr.numel() == 3
assert perm.numel() == 6 assert perm.numel() == 6
unweighted_mat = SparseTensor.from_dense(torch.ones((6, 6), device=device)) unweighted_mat = SparseTensor.from_dense(torch.ones((6, 6), device=device))
mat, partptr, perm = unweighted_mat.partition(num_parts=2, recursive=True) mat, partptr, perm = unweighted_mat.partition(num_parts=2, recursive=True, sort_strategy=True)
assert partptr.numel() == 3 assert partptr.numel() == 3
assert perm.numel() == 6 assert perm.numel() == 6
......
...@@ -11,27 +11,46 @@ def cartesian1d(x, y): ...@@ -11,27 +11,46 @@ def cartesian1d(x, y):
return coos.split(1, dim=1) return coos.split(1, dim=1)
def metis_weight(x): def metis_weight1(x):
sorted_x = x.sort()[0]
diff = sorted_x[1:] - sorted_x[:-1]
if diff.sum() == 0:
return None
xmin, xmax = sorted_x[[0, -1]]
srange = xmax - xmin
min_diff = diff.min()
scale = (min_diff / srange).item()
tick, arange = scale.as_integer_ratio()
x_ratio = (x - xmin) / srange
return (x_ratio * arange + tick).long()
def metis_weight2(x):
t1, t2 = cartesian1d(x, x) t1, t2 = cartesian1d(x, x)
diff = t1 - t2 diff = t1 - t2
diff = diff[diff != 0] diff = diff[diff != 0]
if len(diff) == 0: if len(diff) == 0:
return None return None
res = diff.abs().min() xmin, xmax = x.min(), x.max()
bod = x.max() - x.min() srange = xmax - xmin
scale = (res / bod).item() min_diff = diff.abs().min()
scale = (min_diff / srange).item()
tick, arange = scale.as_integer_ratio() tick, arange = scale.as_integer_ratio()
x_ratio = (x - x.min()) / bod x_ratio = (x - xmin) / srange
return (x_ratio * arange + tick).long() return (x_ratio * arange + tick).long()
def partition(src: SparseTensor, num_parts: int, recursive: bool = False def metis_weight(x, sort_strategy=True):
return metis_weight1(x) if sort_strategy else metis_weight2(x)
def partition(src: SparseTensor, num_parts: int, recursive: bool = False, sort_strategy=True,
) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]: ) -> Tuple[SparseTensor, torch.Tensor, torch.Tensor]:
rowptr, col, value = src.csr() rowptr, col, value = src.csr()
rowptr, col = rowptr.cpu(), col.cpu() rowptr, col = rowptr.cpu(), col.cpu()
if value is not None and value.dim() == 1: if value is not None and value.dim() == 1:
value = value.detach().cpu() value = value.detach().cpu()
value = metis_weight(value) value = metis_weight(value, sort_strategy)
cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts, cluster = torch.ops.torch_sparse.partition(rowptr, col, value, num_parts,
recursive) recursive)
cluster = cluster.to(src.device()) cluster = cluster.to(src.device())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment