Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
288cfd44
Commit
288cfd44
authored
Nov 24, 2020
by
rusty1s
Browse files
add bipartite flag
parent
c493caaf
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
7 deletions
+8
-7
csrc/cpu/relabel_cpu.cpp
csrc/cpu/relabel_cpu.cpp
+5
-4
csrc/cpu/relabel_cpu.h
csrc/cpu/relabel_cpu.h
+1
-1
csrc/relabel.cpp
csrc/relabel.cpp
+2
-2
No files found.
csrc/cpu/relabel_cpu.cpp
View file @
288cfd44
...
@@ -46,7 +46,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
...
@@ -46,7 +46,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch
::
Tensor
>
torch
::
Tensor
>
relabel_one_hop_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
relabel_one_hop_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
Tensor
idx
)
{
torch
::
Tensor
idx
,
bool
bipartite
)
{
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
CHECK_CPU
(
col
);
...
@@ -131,9 +131,10 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
...
@@ -131,9 +131,10 @@ relabel_one_hop_cpu(torch::Tensor rowptr, torch::Tensor col,
}
}
}
}
out_rowptr
=
if
(
bipartite
)
torch
::
cat
({
out_rowptr
,
torch
::
full
({(
int64_t
)
n_ids
.
size
()},
out_rowptr
=
torch
::
cat
(
out_col
.
numel
(),
rowptr
.
options
())});
{
out_rowptr
,
torch
::
full
({(
int64_t
)
n_ids
.
size
()},
out_col
.
numel
(),
rowptr
.
options
())});
idx
=
torch
::
cat
({
idx
,
torch
::
from_blob
(
n_ids
.
data
(),
{(
int64_t
)
n_ids
.
size
()},
idx
=
torch
::
cat
({
idx
,
torch
::
from_blob
(
n_ids
.
data
(),
{(
int64_t
)
n_ids
.
size
()},
idx
.
options
())});
idx
.
options
())});
...
...
csrc/cpu/relabel_cpu.h
View file @
288cfd44
...
@@ -9,4 +9,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
...
@@ -9,4 +9,4 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch
::
Tensor
>
torch
::
Tensor
>
relabel_one_hop_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
relabel_one_hop_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
Tensor
idx
);
torch
::
Tensor
idx
,
bool
bipartite
);
csrc/relabel.cpp
View file @
288cfd44
...
@@ -24,7 +24,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
...
@@ -24,7 +24,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::optional<torch::Tensor>,
torch
::
Tensor
>
torch
::
Tensor
>
relabel_one_hop
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
relabel_one_hop
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
Tensor
idx
)
{
torch
::
Tensor
idx
,
bool
bipartite
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
#ifdef WITH_CUDA
AT_ERROR
(
"No CUDA version supported"
);
AT_ERROR
(
"No CUDA version supported"
);
...
@@ -32,7 +32,7 @@ relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
...
@@ -32,7 +32,7 @@ relabel_one_hop(torch::Tensor rowptr, torch::Tensor col,
AT_ERROR
(
"Not compiled with CUDA support"
);
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
#endif
}
else
{
}
else
{
return
relabel_one_hop_cpu
(
rowptr
,
col
,
optional_value
,
idx
);
return
relabel_one_hop_cpu
(
rowptr
,
col
,
optional_value
,
idx
,
bipartite
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment