Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
8d837c78
Commit
8d837c78
authored
Nov 13, 2020
by
rusty1s
Browse files
relabel one hop
parent
f0609836
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
130 additions
and
1 deletion
+130
-1
csrc/cpu/relabel_cpu.cpp
csrc/cpu/relabel_cpu.cpp
+99
-0
csrc/cpu/relabel_cpu.h
csrc/cpu/relabel_cpu.h
+6
-0
csrc/relabel.cpp
csrc/relabel.cpp
+19
-1
csrc/sparse.h
csrc/sparse.h
+6
-0
No files found.
csrc/cpu/relabel_cpu.cpp
View file @
8d837c78
...
@@ -41,3 +41,102 @@ std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
...
@@ -41,3 +41,102 @@ std::tuple<torch::Tensor, torch::Tensor> relabel_cpu(torch::Tensor col,
return
std
::
make_tuple
(
out_col
,
out_idx
);
return
std
::
make_tuple
(
out_col
,
out_idx
);
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
>
relabel_one_hop_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
Tensor
idx
)
{
CHECK_CPU
(
rowptr
);
CHECK_CPU
(
col
);
if
(
optional_value
.
has_value
())
{
CHECK_CPU
(
optional_value
.
value
());
CHECK_INPUT
(
optional_value
.
value
().
dim
()
==
1
);
}
CHECK_CPU
(
idx
);
auto
rowptr_data
=
rowptr
.
data_ptr
<
int64_t
>
();
auto
col_data
=
col
.
data_ptr
<
int64_t
>
();
auto
idx_data
=
idx
.
data_ptr
<
int64_t
>
();
std
::
vector
<
int64_t
>
n_ids
;
std
::
unordered_map
<
int64_t
,
int64_t
>
n_id_map
;
std
::
unordered_map
<
int64_t
,
int64_t
>::
iterator
it
;
auto
out_rowptr
=
torch
::
empty
(
idx
.
numel
()
+
1
,
rowptr
.
options
());
auto
out_rowptr_data
=
out_rowptr
.
data_ptr
<
int64_t
>
();
out_rowptr_data
[
0
]
=
0
;
int64_t
v
,
w
,
c
,
row_start
,
row_end
,
offset
=
0
;
for
(
int64_t
i
=
0
;
i
<
idx
.
numel
();
i
++
)
{
v
=
idx_data
[
i
];
n_id_map
[
i
]
=
v
;
offset
+=
rowptr_data
[
v
+
1
]
-
rowptr_data
[
v
];
out_rowptr_data
[
i
+
1
]
=
offset
;
}
auto
out_col
=
torch
::
empty
(
offset
,
col
.
options
());
auto
out_col_data
=
out_col
.
data_ptr
<
int64_t
>
();
torch
::
optional
<
torch
::
Tensor
>
out_value
=
torch
::
nullopt
;
if
(
optional_value
.
has_value
())
{
out_value
=
torch
::
empty
(
offset
,
optional_value
.
value
().
options
());
AT_DISPATCH_ALL_TYPES
(
optional_value
.
value
().
scalar_type
(),
"relabel"
,
[
&
]
{
auto
value_data
=
optional_value
.
value
().
data_ptr
<
scalar_t
>
();
auto
out_value_data
=
out_value
.
value
().
data_ptr
<
scalar_t
>
();
offset
=
0
;
for
(
int64_t
i
=
0
;
i
<
idx
.
numel
();
i
++
)
{
v
=
idx_data
[
i
];
row_start
=
rowptr_data
[
v
],
row_end
=
rowptr_data
[
v
+
1
];
for
(
int64_t
j
=
row_start
;
j
<
row_end
;
j
++
)
{
w
=
col_data
[
j
];
it
=
n_id_map
.
find
(
w
);
if
(
it
==
n_id_map
.
end
())
{
c
=
idx
.
numel
()
+
n_ids
.
size
();
n_id_map
[
w
]
=
c
;
n_ids
.
push_back
(
w
);
out_col_data
[
offset
]
=
c
;
}
else
{
out_col_data
[
offset
]
=
it
->
second
;
}
out_value_data
[
offset
]
=
value_data
[
j
];
offset
++
;
}
}
});
}
else
{
offset
=
0
;
for
(
int64_t
i
=
0
;
i
<
idx
.
numel
();
i
++
)
{
v
=
idx_data
[
i
];
row_start
=
rowptr_data
[
v
],
row_end
=
rowptr_data
[
v
+
1
];
for
(
int64_t
j
=
row_start
;
j
<
row_end
;
j
++
)
{
w
=
col_data
[
j
];
it
=
n_id_map
.
find
(
w
);
if
(
it
==
n_id_map
.
end
())
{
c
=
idx
.
numel
()
+
n_ids
.
size
();
n_id_map
[
w
]
=
c
;
n_ids
.
push_back
(
w
);
out_col_data
[
offset
]
=
c
;
}
else
{
out_col_data
[
offset
]
=
it
->
second
;
}
offset
++
;
}
}
}
out_rowptr
=
torch
::
cat
({
out_rowptr
,
torch
::
full
({(
int64_t
)
n_ids
.
size
()},
out_col
.
numel
(),
rowptr
.
options
())});
idx
=
torch
::
cat
(
{
idx
,
torch
::
from_blob
(
n_ids
.
data
(),
{(
int64_t
)
n_ids
.
size
()})});
return
std
::
make_tuple
(
out_rowptr
,
out_col
,
out_value
,
idx
);
}
csrc/cpu/relabel_cpu.h
View file @
8d837c78
...
@@ -4,3 +4,9 @@
...
@@ -4,3 +4,9 @@
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
relabel_cpu
(
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
relabel_cpu
(
torch
::
Tensor
col
,
torch
::
Tensor
idx
);
torch
::
Tensor
idx
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
>
relabel_one_hop_cpu
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
Tensor
idx
);
csrc/relabel.cpp
View file @
8d837c78
...
@@ -20,5 +20,23 @@ std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
...
@@ -20,5 +20,23 @@ std::tuple<torch::Tensor, torch::Tensor> relabel(torch::Tensor col,
}
}
}
}
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
>
relabel_one_hop
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
Tensor
idx
)
{
if
(
rowptr
.
device
().
is_cuda
())
{
#ifdef WITH_CUDA
AT_ERROR
(
"No CUDA version supported"
);
#else
AT_ERROR
(
"Not compiled with CUDA support"
);
#endif
}
else
{
return
relabel_one_hop_cpu
(
rowptr
,
col
,
optional_value
,
idx
);
}
}
static
auto
registry
=
static
auto
registry
=
torch
::
RegisterOperators
().
op
(
"torch_sparse::relabel"
,
&
relabel
);
torch
::
RegisterOperators
()
.
op
(
"torch_sparse::relabel"
,
&
relabel
)
.
op
(
"torch_sparse::relabel_one_hop"
,
&
relabel_one_hop
);
csrc/sparse.h
View file @
8d837c78
...
@@ -18,6 +18,12 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
...
@@ -18,6 +18,12 @@ torch::Tensor mt_partition(torch::Tensor rowptr, torch::Tensor col,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
relabel
(
torch
::
Tensor
col
,
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
relabel
(
torch
::
Tensor
col
,
torch
::
Tensor
idx
);
torch
::
Tensor
idx
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
optional
<
torch
::
Tensor
>
,
torch
::
Tensor
>
relabel_one_hop
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
optional
<
torch
::
Tensor
>
optional_value
,
torch
::
Tensor
idx
);
torch
::
Tensor
random_walk
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
random_walk
(
torch
::
Tensor
rowptr
,
torch
::
Tensor
col
,
torch
::
Tensor
start
,
int64_t
walk_length
);
torch
::
Tensor
start
,
int64_t
walk_length
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment