OpenDAS / torch-cluster · Commits

Commit 5a485e98, authored Mar 12, 2020 by rusty1s

    cuda complete

Parent: 06d9038f

Showing 12 changed files with 259 additions and 365 deletions:

- csrc/cpu/graclus_cpu.cpp (+1, -4)
- csrc/cuda/coloring.cuh (+0, -41)
- csrc/cuda/fps_cuda.cu (+7, -10)
- csrc/cuda/graclus.cpp (+0, -28)
- csrc/cuda/graclus_cuda.cu (+233, -0)
- csrc/cuda/graclus_cuda.h (+6, -0)
- csrc/cuda/graclus_kernel.cu (+0, -40)
- csrc/cuda/knn_cuda.cu (+1, -1)
- csrc/cuda/proposal.cuh (+0, -88)
- csrc/cuda/response.cuh (+0, -93)
- csrc/cuda/utils.cuh (+0, -59)
- torch_cluster/graclus.py (+11, -1)
csrc/cpu/graclus_cpu.cpp (view file @ 5a485e98)

```diff
@@ -4,7 +4,6 @@
 torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
                           torch::optional<torch::Tensor> optional_weight) {
   CHECK_CPU(rowptr);
   CHECK_CPU(col);
   CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
-
@@ -33,11 +32,9 @@ torch::Tensor graclus_cpu(torch::Tensor rowptr, torch::Tensor col,
     out_data[u] = u;

     int64_t row_start = rowptr_data[u], row_end = rowptr_data[u + 1];
-    auto edge_perm = torch::randperm(row_end - row_start, rowptr.options());
-    auto edge_perm_data = edge_perm.data_ptr<int64_t>();
     for (auto e = 0; e < row_end - row_start; e++) {
-      auto v = col_data[row_start + edge_perm_data[e]];
+      auto v = col_data[row_start + e];

       if (out_data[v] >= 0)
         continue;
```
csrc/cuda/coloring.cuh (deleted, 100644 → 0; view file @ 06d9038f)

```cuda
#pragma once

#include <ATen/ATen.h>

#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_PROB 0.53406

__device__ int64_t done;

__global__ void init_done_kernel() { done = 1; }

__global__ void colorize_kernel(int64_t *cluster,
                                float *__restrict__ bernoulli, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] < 0) {
      cluster[u] = (int64_t)bernoulli[u] - 2;
      done = 0;
    }
  }
}

int64_t colorize(at::Tensor cluster) {
  init_done_kernel<<<1, 1>>>();

  auto numel = cluster.size(0);
  auto props = at::full(numel, BLUE_PROB, cluster.options().dtype(at::kFloat));
  auto bernoulli = props.bernoulli();

  colorize_kernel<<<BLOCKS(numel), THREADS>>>(
      cluster.DATA_PTR<int64_t>(), bernoulli.DATA_PTR<float>(), numel);

  int64_t out;
  cudaMemcpyFromSymbol(&out, done, sizeof(out), 0, cudaMemcpyDeviceToHost);
  return out;
}
```
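The deleted header above encodes the whole matching state in a single tensor: values >= 0 mean "matched into that cluster id", while unmatched nodes are recolored each round, blue (-1) with probability BLUE_PROB and red (-2) otherwise. As a rough Python sketch of that encoding, not part of the commit:

```python
import torch

BLUE_PROB = 0.53406

cluster = torch.full((8,), -1, dtype=torch.long)      # all nodes unmatched
bernoulli = torch.full((8,), BLUE_PROB).bernoulli()   # 1 w.p. BLUE_PROB, else 0

# Mirror of colorize_kernel: only recolor still-unmatched (negative) entries.
unmatched = cluster < 0
cluster[unmatched] = bernoulli[unmatched].long() - 2  # 1 -> -1 (blue), 0 -> -2 (red)
```

The host-side colorize() returns done == 1 exactly when no node needed recoloring, i.e. every node is matched, which is what terminates the matching loop.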
csrc/cuda/fps_cuda.cu (view file @ 5a485e98)

```diff
@@ -6,15 +6,11 @@
 #define THREADS 1024

-inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
-  return (x - x[idx]).norm(2, 1);
-}
-
 template <typename scalar_t> struct Dist<scalar_t> {
   static inline __device__ void compute(int64_t idx, int64_t start_idx,
                                         int64_t end_idx, int64_t old,
                                         scalar_t *best, int64_t *best_idx,
-                                        const scalar_t *x, scalar_t *dist,
+                                        const scalar_t *src, scalar_t *dist,
                                         scalar_t *tmp_dist, int64_t dim) {
     for (int64_t n = start_idx + idx; n < end_idx; n += THREADS) {
@@ -23,7 +19,7 @@ template <typename scalar_t> struct Dist<scalar_t> {
     __syncthreads();

     for (int64_t i = start_idx * dim + idx; i < end_idx * dim; i += THREADS) {
-      scalar_t d = x[(old * dim) + (i % dim)] - x[i];
+      scalar_t d = src[(old * dim) + (i % dim)] - src[i];
       atomicAdd(&tmp_dist[i / dim], d * d);
     }
@@ -39,7 +35,7 @@ template <typename scalar_t> struct Dist<scalar_t> {
 };

 template <typename scalar_t>
-__global__ void fps_kernel(const scalar_t *x, const int64_t *ptr,
+__global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
                            const int64_t *out_ptr, const int64_t *start,
                            scalar_t *dist, scalar_t *tmp_dist, int64_t *out,
                            int64_t dim) {
@@ -63,7 +59,7 @@ __global__ void fps_kernel(const scalar_t *x, const int64_t *ptr,
     __syncthreads();

     Dist<scalar_t, Dim>::compute(thread_idx, start_idx, end_idx, out[m - 1],
-                                 &best, &best_idx, x, dist, tmp_dist, dim);
+                                 &best, &best_idx, src, dist, tmp_dist, dim);
     best_dist[idx] = best;
     best_dist_idx[idx] = best_idx;
@@ -94,6 +90,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
   CHECK_CUDA(ptr);
   CHECK_INPUT(ptr.dim() == 1);
   AT_ASSERTM(ratio > 0 and ratio < 1, "Invalid input");
+  cudaSetDevice(src.get_device());

   src = src.view({src.size(0), -1}).contiguous();
   ptr = ptr.contiguous();
@@ -106,7 +103,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
   torch::Tensor start;
   if (random_start) {
-    start = at::rand(batch_size, src.options());
+    start = torch::rand(batch_size, src.options());
     start = (start * deg.toType(torch::kFloat)).toType(torch::kLong);
   } else {
     start = torch::zeros(batch_size, ptr.options());
@@ -118,7 +115,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, double ratio,
   auto out_size = (int64_t *)malloc(sizeof(int64_t));
   cudaMemcpy(out_size, out_ptr[-1].data_ptr<int64_t>(), sizeof(int64_t),
              cudaMemcpyDeviceToHost);
-  auto out = at::empty(out_size[0], out_ptr.options());
+  auto out = torch::empty(out_size[0], out_ptr.options());

   auto stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES(src.scalar_type(), "fps_kernel", [&] {
```
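For context, a minimal usage sketch of the farthest-point-sampling entry point whose kernel is renamed above; `fps` is torch_cluster's public wrapper, and the toy tensors are hypothetical rather than taken from the commit:

```python
import torch
from torch_cluster import fps

src = torch.rand(6, 3, device='cuda')                    # two batches of 3 points in R^3
batch = torch.tensor([0, 0, 0, 1, 1, 1], device='cuda')  # batch assignment per point

idx = fps(src, batch, ratio=0.5, random_start=True)      # indices of sampled points
sampled = src[idx]                                       # roughly half of each batch
```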
csrc/cuda/graclus.cpp (deleted, 100644 → 0; view file @ 06d9038f)

```cpp
#include <torch/extension.h>

#define CHECK_CUDA(x) \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")

at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes);
at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
                                 at::Tensor weight, int64_t num_nodes);

at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  CHECK_CUDA(row);
  CHECK_CUDA(col);
  return graclus_cuda(row, col, num_nodes);
}

at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
                            int64_t num_nodes) {
  CHECK_CUDA(row);
  CHECK_CUDA(col);
  CHECK_CUDA(weight);
  return weighted_graclus_cuda(row, col, weight, num_nodes);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("graclus", &graclus, "Graclus (CUDA)");
  m.def("weighted_graclus", &weighted_graclus, "Weighted Graclus (CUDA)");
}
```
csrc/cuda/graclus_cuda.cu (new file, 0 → 100644; view file @ 5a485e98)

```cuda
#include "graclus_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.h"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS
#define BLUE_P 0.53406

// Matching phases, defined below.
int64_t colorize(torch::Tensor out);
void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
             torch::Tensor col, torch::optional<torch::Tensor> optional_weight);
void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
             torch::Tensor col, torch::optional<torch::Tensor> optional_weight);

torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
                           torch::optional<torch::Tensor> optional_weight) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
  if (optional_weight.has_value()) {
    CHECK_CUDA(optional_weight.value());
    CHECK_INPUT(optional_weight.value().dim() == 1);
    CHECK_INPUT(optional_weight.value().numel() == col.numel());
  }
  cudaSetDevice(rowptr.get_device());

  int64_t num_nodes = rowptr.numel() - 1;
  auto out = torch::full(num_nodes, -1, rowptr.options());
  auto proposal = torch::full(num_nodes, -1, rowptr.options());

  while (!colorize(out)) {
    propose(out, proposal, rowptr, col, optional_weight);
    respond(out, proposal, rowptr, col, optional_weight);
  }

  return out;
}

__device__ int64_t done_d;

__global__ void init_done_kernel() { done_d = 1; }

__global__ void colorize_kernel(int64_t *out, const float *bernoulli,
                                int64_t numel) {
  const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u < numel) {
    if (out[u] < 0) {
      out[u] = (int64_t)bernoulli[u] - 2;
      done_d = 0;
    }
  }
}

int64_t colorize(torch::Tensor out) {
  auto stream = at::cuda::getCurrentCUDAStream();
  init_done_kernel<<<1, 1, 0, stream>>>();

  auto numel = out.size(0);
  auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
  auto bernoulli = props.bernoulli();

  colorize_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
      out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);

  int64_t done_h;
  cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0,
                       cudaMemcpyDeviceToHost);
  return done_h;
}

__global__ void propose_kernel(int64_t *out, int64_t *proposal,
                               const int64_t *rowptr, const int64_t *col,
                               int64_t numel) {
  const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u < numel) {
    if (out[u] != -1)
      return; // Only visit blue nodes.

    bool has_unmatched_neighbor = false;
    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
      auto v = col[i];
      if (out[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (out[v] == -2) {
        proposal[u] = v; // Propose to first red neighbor.
        break;
      }
    }
    if (!has_unmatched_neighbor)
      out[u] = u;
  }
}

template <typename scalar_t>
__global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
                                        const int64_t *rowptr,
                                        const int64_t *col,
                                        const scalar_t *weight, int64_t numel) {
  const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u < numel) {
    if (out[u] != -1)
      return; // Only visit blue nodes.

    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    scalar_t w_max = 0;
    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
      auto v = col[i];
      if (out[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      // Find maximum weighted red neighbor.
      if (out[v] == -2 && weight[i] >= w_max) {
        v_max = v;
        w_max = weight[i];
      }
    }
    proposal[u] = v_max; // Propose.
    if (!has_unmatched_neighbor)
      out[u] = u;
  }
}

void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
             torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
  auto stream = at::cuda::getCurrentCUDAStream();
  if (!optional_weight.has_value()) {
    propose_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
  } else {
    auto weight = optional_weight.value();
    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
      weighted_propose_kernel<scalar_t>
          <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
              out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
              rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
              weight.data_ptr<scalar_t>(), out.numel());
    });
  }
}

__global__ void respond_kernel(int64_t *out, const int64_t *proposal,
                               const int64_t *rowptr, const int64_t *col,
                               int64_t numel) {
  const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u < numel) {
    if (out[u] != -2)
      return; // Only visit red nodes.

    bool has_unmatched_neighbor = false;
    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
      auto v = col[i];
      if (out[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (out[v] == -1 && proposal[v] == u) {
        // Match first blue neighbor v which proposed to u.
        out[u] = min(u, v);
        out[v] = min(u, v);
        break;
      }
    }
    if (!has_unmatched_neighbor)
      out[u] = u;
  }
}

template <typename scalar_t>
__global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
                                        const int64_t *rowptr,
                                        const int64_t *col,
                                        const scalar_t *weight, int64_t numel) {
  const int64_t u = blockIdx.x * blockDim.x + threadIdx.x;
  if (u < numel) {
    if (out[u] != -2)
      return; // Only visit red nodes.

    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    scalar_t w_max = 0;
    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
      auto v = col[i];
      if (out[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (out[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
        // Find maximum weighted blue neighbor v which proposed to u.
        v_max = v;
        w_max = weight[i];
      }
    }
    if (v_max >= 0) {
      out[u] = min(u, v_max); // Match neighbors.
      out[v_max] = min(u, v_max);
    }
    if (!has_unmatched_neighbor)
      out[u] = u;
  }
}

void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
             torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
  auto stream = at::cuda::getCurrentCUDAStream();
  if (!optional_weight.has_value()) {
    respond_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
        out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
  } else {
    auto weight = optional_weight.value();
    AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
      weighted_respond_kernel<scalar_t>
          <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
              out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
              rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
              weight.data_ptr<scalar_t>(), out.numel());
    });
  }
}
```
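To make the propose/respond handshake above concrete, here is a hedged Python sketch of a single round on a toy path graph 0-1-2 with a fixed coloring. It mirrors the unweighted kernels only and omits the has_unmatched_neighbor bookkeeping that retires isolated nodes; the toy inputs are hypothetical:

```python
import torch

rowptr = torch.tensor([0, 1, 3, 4])   # CSR of the path graph 0 - 1 - 2
col = torch.tensor([1, 0, 2, 1])
out = torch.tensor([-1, -2, -1])      # one possible coloring: 0 and 2 blue, 1 red
proposal = torch.full((3,), -1)

# Propose: every blue node proposes to its first red neighbor.
for u in range(3):
    if out[u] == -1:
        for i in range(int(rowptr[u]), int(rowptr[u + 1])):
            v = int(col[i])
            if out[v] == -2:
                proposal[u] = v
                break

# Respond: every red node matches the first blue neighbor that proposed to it.
for u in range(3):
    if out[u] == -2:
        for i in range(int(rowptr[u]), int(rowptr[u + 1])):
            v = int(col[i])
            if out[v] == -1 and proposal[v] == u:
                out[u] = min(u, v)
                out[v] = min(u, v)
                break

print(out)  # tensor([0, 0, -1]); node 2 stays unmatched until a later round
```

Each round either matches more nodes or retires them as singletons, so colorize() eventually reports that everything is matched and the loop in graclus_cuda terminates.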
csrc/cuda/graclus_cuda.h (new file, 0 → 100644; view file @ 5a485e98)

```cpp
#pragma once

#include <torch/extension.h>

torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
                           torch::optional<torch::Tensor> optional_weight);
```
csrc/cuda/graclus_kernel.cu (deleted, 100644 → 0; view file @ 06d9038f)

```cuda
#include <ATen/ATen.h>

#include "coloring.cuh"
#include "proposal.cuh"
#include "response.cuh"
#include "utils.cuh"

at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  cudaSetDevice(row.get_device());
  std::tie(row, col) = remove_self_loops(row, col);
  std::tie(row, col) = rand(row, col);
  std::tie(row, col) = to_csr(row, col, num_nodes);
  auto cluster = at::full(num_nodes, -1, row.options());
  auto proposal = at::full(num_nodes, -1, row.options());

  while (!colorize(cluster)) {
    propose(cluster, proposal, row, col);
    respond(cluster, proposal, row, col);
  }

  return cluster;
}

at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
                                 at::Tensor weight, int64_t num_nodes) {
  cudaSetDevice(row.get_device());
  std::tie(row, col, weight) = remove_self_loops(row, col, weight);
  std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
  auto cluster = at::full(num_nodes, -1, row.options());
  auto proposal = at::full(num_nodes, -1, row.options());

  while (!colorize(cluster)) {
    propose(cluster, proposal, row, col, weight);
    respond(cluster, proposal, row, col, weight);
  }

  return cluster;
}
```
csrc/cuda/knn_cuda.cu (view file @ 5a485e98)

```diff
@@ -100,5 +100,5 @@ torch::Tensor knn_cuda(torch::Tensor x, torch::Tensor y, torch::Tensor ptr_x,
   });

   auto mask = col != -1;
-  return at::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
+  return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0);
 }
```
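A minimal usage sketch of the k-NN entry point touched above; `knn` is torch_cluster's public wrapper around knn_cuda, and the toy tensors are hypothetical. The `col != -1` mask in the return statement is what drops query points with fewer than k neighbors, so the result can have fewer than `y.size(0) * k` pairs:

```python
import torch
from torch_cluster import knn

x = torch.rand(8, 2, device='cuda')  # reference points
y = torch.rand(4, 2, device='cuda')  # query points

assign_index = knn(x, y, k=2)        # shape [2, num_pairs]: row indexes y, col indexes x
```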
csrc/cuda/proposal.cuh (deleted, 100644 → 0; view file @ 06d9038f)

```cuda
#pragma once

#include <ATen/ATen.h>

#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

__global__ void propose_kernel(int64_t *__restrict__ cluster,
                               int64_t *proposal, int64_t *__restrict row,
                               int64_t *__restrict__ col, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -1)
      continue; // Only visit blue nodes.

    bool has_unmatched_neighbor = false;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (cluster[v] == -2) {
        proposal[u] = v; // Propose to first red neighbor.
        break;
      }
    }
    if (!has_unmatched_neighbor)
      cluster[u] = u;
  }
}

void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  propose_kernel<<<BLOCKS(cluster.numel()), THREADS>>>(
      cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), cluster.numel());
}

template <typename scalar_t>
__global__ void propose_kernel(int64_t *__restrict__ cluster,
                               int64_t *proposal, int64_t *__restrict row,
                               int64_t *__restrict__ col,
                               scalar_t *__restrict__ weight, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -1)
      continue; // Only visit blue nodes.

    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    scalar_t w_max = 0;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      // Find maximum weighted red neighbor.
      if (cluster[v] == -2 && weight[i] >= w_max) {
        v_max = v;
        w_max = weight[i];
      }
    }
    proposal[u] = v_max; // Propose.
    if (!has_unmatched_neighbor)
      cluster[u] = u;
  }
}

void propose(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
    propose_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
        cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
        weight.DATA_PTR<scalar_t>(), cluster.numel());
  });
}
```
csrc/cuda/response.cuh (deleted, 100644 → 0; view file @ 06d9038f)

```cuda
#pragma once

#include <ATen/ATen.h>

#include "compat.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

__global__ void respond_kernel(int64_t *__restrict__ cluster,
                               int64_t *proposal, int64_t *__restrict row,
                               int64_t *__restrict__ col, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -2)
      continue; // Only visit red nodes.

    bool has_unmatched_neighbor = false;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (cluster[v] == -1 && proposal[v] == u) {
        // Match first blue neighbor v which proposed to u.
        cluster[u] = min(u, v);
        cluster[v] = min(u, v);
        break;
      }
    }
    if (!has_unmatched_neighbor)
      cluster[u] = u;
  }
}

void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col) {
  respond_kernel<<<BLOCKS(cluster.numel()), THREADS>>>(
      cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(), cluster.numel());
}

template <typename scalar_t>
__global__ void respond_kernel(int64_t *__restrict__ cluster,
                               int64_t *proposal, int64_t *__restrict row,
                               int64_t *__restrict__ col,
                               scalar_t *__restrict__ weight, size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (int64_t u = index; u < numel; u += stride) {
    if (cluster[u] != -2)
      continue; // Only visit red nodes.

    bool has_unmatched_neighbor = false;
    int64_t v_max = -1;
    scalar_t w_max = 0;
    for (int64_t i = row[u]; i < row[u + 1]; i++) {
      auto v = col[i];
      if (cluster[v] < 0)
        has_unmatched_neighbor = true; // Unmatched neighbor found.
      if (cluster[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
        // Find maximum weighted blue neighbor v which proposed to u.
        v_max = v;
        w_max = weight[i];
      }
    }
    if (v_max >= 0) {
      cluster[u] = min(u, v_max); // Match neighbors.
      cluster[v_max] = min(u, v_max);
    }
    if (!has_unmatched_neighbor)
      cluster[u] = u;
  }
}

void respond(at::Tensor cluster, at::Tensor proposal, at::Tensor row,
             at::Tensor col, at::Tensor weight) {
  AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
    respond_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
        cluster.DATA_PTR<int64_t>(), proposal.DATA_PTR<int64_t>(),
        row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
        weight.DATA_PTR<scalar_t>(), cluster.numel());
  });
}
```
csrc/cuda/utils.cuh (view file @ 5a485e98)

```diff
@@ -5,62 +5,3 @@
 #define CHECK_CUDA(x) \
   AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
 #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch")
-
-////////////////////////////////////////////////////////////////////////
-
-#include <ATen/ATen.h>
-
-std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
-                                                     at::Tensor col) {
-  auto mask = row != col;
-  return std::make_tuple(row.masked_select(mask), col.masked_select(mask));
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor>
-remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
-  auto mask = row != col;
-  return std::make_tuple(row.masked_select(mask), col.masked_select(mask),
-                         weight.masked_select(mask));
-}
-
-std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
-  auto perm = at::empty(row.size(0), row.options());
-  at::randperm_out(perm, row.size(0));
-  return std::make_tuple(row.index_select(0, perm), col.index_select(0, perm));
-}
-
-std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row, at::Tensor col) {
-  at::Tensor perm;
-  std::tie(row, perm) = row.sort();
-  return std::make_tuple(row, col.index_select(0, perm));
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor>
-sort_by_row(at::Tensor row, at::Tensor col, at::Tensor weight) {
-  at::Tensor perm;
-  std::tie(row, perm) = row.sort();
-  return std::make_tuple(row, col.index_select(0, perm),
-                         weight.index_select(0, perm));
-}
-
-at::Tensor degree(at::Tensor row, int64_t num_nodes) {
-  auto zero = at::zeros(num_nodes, row.options());
-  auto one = at::ones(row.size(0), row.options());
-  return zero.scatter_add_(0, row, one);
-}
-
-std::tuple<at::Tensor, at::Tensor> to_csr(at::Tensor row, at::Tensor col,
-                                          int64_t num_nodes) {
-  std::tie(row, col) = sort_by_row(row, col);
-  row = degree(row, num_nodes).cumsum(0);
-  row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
-  return std::make_tuple(row, col);
-}
-
-std::tuple<at::Tensor, at::Tensor, at::Tensor>
-to_csr(at::Tensor row, at::Tensor col, at::Tensor weight, int64_t num_nodes) {
-  std::tie(row, col, weight) = sort_by_row(row, col, weight);
-  row = degree(row, num_nodes).cumsum(0);
-  row = at::cat({at::zeros(1, row.options()), row}, 0); // Prepend zero.
-  return std::make_tuple(row, col, weight);
-}
```
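The deleted C++ helpers (remove_self_loops, sort_by_row, degree, to_csr) are replaced by equivalent preprocessing on the Python side, shown in the graclus.py hunk below. As a hedged PyTorch sketch of what to_csr() computed, on a hypothetical toy edge list:

```python
import torch

num_nodes = 3
row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1])

perm = torch.argsort(row)                              # sort_by_row()
row, col = row[perm], col[perm]

deg = torch.zeros(num_nodes, dtype=torch.long)         # degree()
deg.scatter_add_(0, row, torch.ones_like(row))

rowptr = torch.cat([deg.new_zeros(1), deg.cumsum(0)])  # prepend zero
# rowptr == tensor([0, 1, 3, 4]); col[rowptr[u]:rowptr[u + 1]] lists u's neighbors.
```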
torch_cluster/graclus.py (view file @ 5a485e98)

```diff
@@ -32,7 +32,17 @@ def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
     if num_nodes is None:
         num_nodes = max(int(row.max()), int(col.max())) + 1

-    perm = torch.argsort(row * num_nodes + col)
+    # Remove self-loops.
+    mask = row != col
+    row, col = row[mask], col[mask]
+
+    # Randomly shuffle nodes.
+    if weight is None:
+        perm = torch.randperm(row.size(0), device=row.device)
+        row, col = row[perm], col[perm]
+
+    # To CSR.
+    perm = torch.argsort(row)
     row, col = row[perm], col[perm]

     deg = row.new_zeros(num_nodes)
```
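Putting it together, a hedged end-to-end sketch of graclus_cluster with the new Python-side preprocessing; the toy inputs are hypothetical, and the exact output depends on the random coloring:

```python
import torch
from torch_cluster import graclus_cluster

row = torch.tensor([0, 1, 1, 2])
col = torch.tensor([1, 0, 2, 1])
weight = torch.tensor([1.0, 1.0, 2.0, 2.0])

cluster = graclus_cluster(row, col, weight)
print(cluster)  # e.g. tensor([0, 1, 1]): the heavier edge (1, 2) gets matched
```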