"...text-generation-inference.git" did not exist on "bfddfa5955dd6558814d313e4364ddf534848632"
Commit 713fb60a authored by rusty1s's avatar rusty1s
Browse files

small fixes

parent 5a485e98
...@@ -32,9 +32,9 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size, ...@@ -32,9 +32,9 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
cudaSetDevice(pos.get_device()); cudaSetDevice(pos.get_device());
if (optional_start.has_value()) if (optional_start.has_value())
CHECK_CPU(optional_start.value()); CHECK_CUDA(optional_start.value());
if (optional_start.has_value()) if (optional_start.has_value())
CHECK_CPU(optional_start.value()); CHECK_CUDA(optional_start.value());
pos = pos.view({pos.size(0), -1}).contiguous(); pos = pos.view({pos.size(0), -1}).contiguous();
size = size.contiguous(); size = size.contiguous();
......
...@@ -36,15 +36,21 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor, ...@@ -36,15 +36,21 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
mask = row == col mask = row == col
row, col = row[mask], col[mask] row, col = row[mask], col[mask]
# Randomly shuffle nodes.
if weight is not None: if weight is not None:
perm = torch.randperm(row.size(0), device=row.device) weight = weight[mask]
# Randomly shuffle nodes.
if weight is None:
perm = torch.randperm(row.size(0), dtype=torch.long, device=row.device)
row, col = row[perm], col[perm] row, col = row[perm], col[perm]
# To CSR. # To CSR.
perm = torch.argsort(row) perm = torch.argsort(row)
row, col = row[perm], col[perm] row, col = row[perm], col[perm]
if weight is not None:
weight = weight[perm]
deg = row.new_zeros(num_nodes) deg = row.new_zeros(num_nodes)
deg.scatter_add_(0, row, torch.ones_like(row)) deg.scatter_add_(0, row, torch.ones_like(row))
rowptr = row.new_zeros(num_nodes + 1) rowptr = row.new_zeros(num_nodes + 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment