OpenDAS / torch-cluster · Commits

Commit 06df4d9b, authored Mar 13, 2020 by rusty1s
Parent: 8c8014b9

graclus fix

Showing 1 changed file, csrc/cuda/graclus_cuda.cu, with 69 additions and 66 deletions.
--- a/csrc/cuda/graclus_cuda.cu
+++ b/csrc/cuda/graclus_cuda.cu
@@ -2,61 +2,37 @@
 #include <ATen/cuda/CUDAContext.h>
 
-#include "utils.h"
+#include "utils.cuh"
 
 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS
 #define BLUE_P 0.53406
 
-torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
-                           torch::optional<torch::Tensor> optional_weight) {
-  CHECK_CUDA(rowptr);
-  CHECK_CUDA(col);
-  CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
-  if (optional_weight.has_value()) {
-    CHECK_CUDA(optional_weight.value());
-    CHECK_INPUT(optional_weight.value().dim() == 1);
-    CHECK_INPUT(optional_weight.value().numel() == col.numel());
-  }
-  cudaSetDevice(rowptr.get_device());
-
-  int64_t num_nodes = rowptr.numel() - 1;
-  auto out = torch::full(num_nodes, -1, rowptr.options());
-  auto proposal = torch::full(num_nodes, -1, rowptr.options());
-
-  while (!colorize(out)) {
-    propose(out, proposal, rowptr, col, optional_weight);
-    respond(out, proposal, rowptr, col, optional_weight);
-  }
-
-  return out;
-}
-
-__device__ int64_t done_d;
-__global__ void init_done_kernel() { done_d = 1; }
+__device__ bool done_d;
+__global__ void init_done_kernel() { done_d = true; }
 
 __global__ void colorize_kernel(int64_t *out, const float *bernoulli,
                                 int64_t numel) {
   const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {
-    if (out[u] < 0) {
-      out[u] = (int64_t)bernoulli[u] - 2;
-      done_d = 0;
+    if (out[thread_idx] < 0) {
+      out[thread_idx] = (int64_t)bernoulli[thread_idx] - 2;
+      done_d = false;
     }
   }
 }
 
-int64_t colorize(torch::Tensor out) {
+bool colorize(torch::Tensor out) {
   auto stream = at::cuda::getCurrentCUDAStream();
   init_done_kernel<<<1, 1, 0, stream>>>();
 
-  auto numel = cluster.size(0);
+  auto numel = out.size(0);
   auto props = torch::full(numel, BLUE_P, out.options().dtype(torch::kFloat));
   auto bernoulli = props.bernoulli();
 
   colorize_kernel<<<BLOCKS(numel), THREADS, 0, stream>>>(
       out.data_ptr<int64_t>(), bernoulli.data_ptr<float>(), numel);
 
-  int64_t done_h;
+  bool done_h;
   cudaMemcpyFromSymbol(&done_h, done_d, sizeof(done_h), 0,
                        cudaMemcpyDeviceToHost);
   return done_h;
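The hunk above turns done_d into a __device__ bool that the host-side colorize() reads back with cudaMemcpyFromSymbol. As a point of reference, a minimal stand-alone sketch of that device-flag pattern follows; the names flag_d, init_flag_kernel, and clear_flag_kernel are illustrative only and do not come from this commit.

// Stand-alone sketch of the __device__ flag pattern used by colorize() above.
// Hypothetical names; compile with nvcc.
#include <cstdio>
#include <cuda_runtime.h>

__device__ bool flag_d;

__global__ void init_flag_kernel() { flag_d = true; }
__global__ void clear_flag_kernel() { flag_d = false; }

int main() {
  init_flag_kernel<<<1, 1>>>();
  clear_flag_kernel<<<1, 1>>>();

  // Copy the device-side symbol back to the host, as colorize() does for done_d.
  bool flag_h;
  cudaMemcpyFromSymbol(&flag_h, flag_d, sizeof(flag_h), 0,
                       cudaMemcpyDeviceToHost);
  printf("flag_h = %d\n", flag_h ? 1 : 0); // prints 0: the second kernel cleared the flag
  return 0;
}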
@@ -68,25 +44,25 @@ __global__ void propose_kernel(int64_t *out, int64_t *proposal,
   const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {
-    if (out[u] != -1)
-      continue; // Only vist blue nodes.
+    if (out[thread_idx] != -1)
+      return; // Only vist blue nodes.
 
     bool has_unmatched_neighbor = false;
-    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
+    for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
       auto v = col[i];
 
       if (out[v] < 0)
         has_unmatched_neighbor = true; // Unmatched neighbor found.
 
       if (out[v] == -2) {
-        proposal[u] = v; // Propose to first red neighbor.
+        proposal[thread_idx] = v; // Propose to first red neighbor.
         break;
       }
     }
 
     if (!has_unmatched_neighbor)
-      out[u] = u;
+      out[thread_idx] = thread_idx;
   }
 }
@@ -98,14 +74,14 @@ __global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
   const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {
-    if (out[u] != -1)
-      continue; // Only vist blue nodes.
+    if (out[thread_idx] != -1)
+      return; // Only vist blue nodes.
 
     bool has_unmatched_neighbor = false;
     int64_t v_max = -1;
     scalar_t w_max = 0;
-    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
+    for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
       auto v = col[i];
 
       if (out[v] < 0)
@@ -118,24 +94,25 @@ __global__ void weighted_propose_kernel(int64_t *out, int64_t *proposal,
       }
     }
-    proposal[u] = v_max; // Propose.
+    proposal[thread_idx] = v_max; // Propose.
 
     if (!has_unmatched_neighbor)
-      out[u] = u;
+      out[thread_idx] = thread_idx;
   }
 }
 
 void propose(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
-             torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
+             torch::Tensor col,
+             torch::optional<torch::Tensor> optional_weight) {
   auto stream = at::cuda::getCurrentCUDAStream();
   if (!optional_weight.has_value()) {
     propose_kernel<<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
         out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
-        row.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
+        rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
   } else {
-    auto = optional_weight.value();
+    auto weight = optional_weight.value();
     AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "propose_kernel", [&] {
       weighted_propose_kernel<scalar_t>
           <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
@@ -151,27 +128,27 @@ __global__ void respond_kernel(int64_t *out, const int64_t *proposal,
                                int64_t numel) {
   const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {
-    if (out[u] != -2)
-      continue; // Only vist red nodes.
+    if (out[thread_idx] != -2)
+      return; // Only vist red nodes.
 
     bool has_unmatched_neighbor = false;
-    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
+    for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
       auto v = col[i];
 
       if (out[v] < 0)
         has_unmatched_neighbor = true; // Unmatched neighbor found.
 
-      if (out[v] == -1 && proposal[v] == u) {
+      if (out[v] == -1 && proposal[v] == thread_idx) {
         // Match first blue neighbhor v which proposed to u.
-        out[u] = min(u, v);
-        out[v] = min(u, v);
+        out[thread_idx] = min(thread_idx, v);
+        out[v] = min(thread_idx, v);
         break;
       }
     }
 
     if (!has_unmatched_neighbor)
-      cluster[u] = u;
+      out[thread_idx] = thread_idx;
   }
 }
@@ -182,20 +159,20 @@ __global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
                                         const scalar_t *weight, int64_t numel) {
   const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {
-    if (out[u] != -2)
-      continue; // Only vist red nodes.
+    if (out[thread_idx] != -2)
+      return; // Only vist red nodes.
 
     bool has_unmatched_neighbor = false;
     int64_t v_max = -1;
     scalar_t w_max = 0;
-    for (int64_t i = rowptr[u]; i < rowptr[u + 1]; i++) {
+    for (int64_t i = rowptr[thread_idx]; i < rowptr[thread_idx + 1]; i++) {
       auto v = col[i];
 
       if (out[v] < 0)
         has_unmatched_neighbor = true; // Unmatched neighbor found.
 
-      if (out[v] == -1 && proposal[v] == u && weight[i] >= w_max) {
+      if (out[v] == -1 && proposal[v] == thread_idx && weight[i] >= w_max) {
         // Find maximum weighted blue neighbhor v which proposed to u.
         v_max = v;
         w_max = weight[i];
@@ -203,17 +180,18 @@ __global__ void weighted_respond_kernel(int64_t *out, const int64_t *proposal,
     }
 
     if (v_max >= 0) {
-      out[u] = min(u, v_max); // Match neighbors.
-      out[v_max] = min(u, v_max);
+      out[thread_idx] = min(thread_idx, v_max); // Match neighbors.
+      out[v_max] = min(thread_idx, v_max);
     }
 
     if (!has_unmatched_neighbor)
-      out[u] = u;
+      out[thread_idx] = thread_idx;
   }
 }
 
 void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
-             torch::Tensor col, torch::optional<torch::Tensor> optional_weight) {
+             torch::Tensor col,
+             torch::optional<torch::Tensor> optional_weight) {
   auto stream = at::cuda::getCurrentCUDAStream();
@@ -222,12 +200,37 @@ void respond(torch::Tensor out, torch::Tensor proposal, torch::Tensor rowptr,
         out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
         rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(), out.numel());
   } else {
-    auto = optional_weight.value();
+    auto weight = optional_weight.value();
     AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "respond_kernel", [&] {
-      respond_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
-          out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
-          rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
-          weight.data_ptr<scalar_t>(), out.numel());
+      weighted_respond_kernel<scalar_t>
+          <<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
+              out.data_ptr<int64_t>(), proposal.data_ptr<int64_t>(),
+              rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
+              weight.data_ptr<scalar_t>(), out.numel());
     });
   }
 }
+
+torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
+                           torch::optional<torch::Tensor> optional_weight) {
+  CHECK_CUDA(rowptr);
+  CHECK_CUDA(col);
+  CHECK_INPUT(rowptr.dim() == 1 && col.dim() == 1);
+  if (optional_weight.has_value()) {
+    CHECK_CUDA(optional_weight.value());
+    CHECK_INPUT(optional_weight.value().dim() == 1);
+    CHECK_INPUT(optional_weight.value().numel() == col.numel());
+  }
+  cudaSetDevice(rowptr.get_device());
+
+  int64_t num_nodes = rowptr.numel() - 1;
+  auto out = torch::full(num_nodes, -1, rowptr.options());
+  auto proposal = torch::full(num_nodes, -1, rowptr.options());
+
+  while (!colorize(out)) {
+    propose(out, proposal, rowptr, col, optional_weight);
+    respond(out, proposal, rowptr, col, optional_weight);
+  }
+
+  return out;
+}
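For context, the graclus_cuda entry point that this commit moves to the end of the file can be driven from a small libtorch program. The sketch below is illustrative only and not part of the commit: the function declaration and argument types are taken from the diff, while the tiny CSR graph, main() wrapper, and build assumptions (linking against libtorch and this CUDA extension) are made up for the example.

// Hypothetical driver for the graclus_cuda entry point shown above.
#include <iostream>
#include <torch/torch.h>

// Declaration copied from the diff; the definition lives in graclus_cuda.cu.
torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col,
                           torch::optional<torch::Tensor> optional_weight);

int main() {
  auto opts = torch::TensorOptions().dtype(torch::kLong).device(torch::kCUDA);

  // Undirected path graph 0-1-2 in CSR form: 0 -> {1}, 1 -> {0, 2}, 2 -> {1}.
  auto rowptr = torch::tensor({0, 1, 3, 4}, opts);
  auto col = torch::tensor({1, 0, 2, 1}, opts);

  // Empty optional = unweighted matching; pass a 1-D weight tensor with one
  // entry per edge in col to exercise the weighted kernels instead.
  auto cluster = graclus_cuda(rowptr, col, {});

  // Each entry holds the smaller node index of the matched pair,
  // or the node's own index if it stayed unmatched.
  std::cout << cluster << std::endl;
  return 0;
}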