Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
014c4bae
Commit
014c4bae
authored
Aug 10, 2021
by
rusty1s
Browse files
hetero neighbor sampling
parent
9532032e
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
460 additions
and
11 deletions
+460
-11
csrc/cpu/hgt_sample_cpu.cpp
csrc/cpu/hgt_sample_cpu.cpp
+1
-1
csrc/cpu/neighbor_sample_cpu.cpp
csrc/cpu/neighbor_sample_cpu.cpp
+333
-0
csrc/cpu/neighbor_sample_cpu.h
csrc/cpu/neighbor_sample_cpu.h
+24
-0
csrc/cpu/utils.h
csrc/cpu/utils.h
+57
-9
csrc/hgt_sample.cpp
csrc/hgt_sample.cpp
+1
-0
csrc/neighbor_sample.cpp
csrc/neighbor_sample.cpp
+42
-0
torch_sparse/__init__.py
torch_sparse/__init__.py
+2
-1
No files found.
csrc/cpu/hgt_sample_cpu.cpp
View file @
014c4bae
...
@@ -102,7 +102,7 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
...
@@ -102,7 +102,7 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &colptr_dict,
const
int64_t
num_hops
)
{
const
int64_t
num_hops
)
{
// Create a mapping to convert single string relations to edge type triplets:
// Create a mapping to convert single string relations to edge type triplets:
std
::
unordered_map
<
rel_t
,
edge_t
>
to_edge_type
;
unordered_map
<
rel_t
,
edge_t
>
to_edge_type
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
for
(
const
auto
&
kv
:
colptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
rel_type
=
kv
.
key
();
to_edge_type
[
rel_type
]
=
split
(
rel_type
);
to_edge_type
[
rel_type
]
=
split
(
rel_type
);
...
...
csrc/cpu/neighbor_sample_cpu.cpp
0 → 100644
View file @
014c4bae
#include "neighbor_sample_cpu.h"
#include "utils.h"
using
namespace
std
;
namespace
{
template
<
bool
replace
,
bool
directed
>
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
sample
(
const
torch
::
Tensor
&
colptr
,
const
torch
::
Tensor
&
row
,
const
torch
::
Tensor
&
input_node
,
const
vector
<
int64_t
>
num_neighbors
)
{
// Initialize some data structures for the sampling process:
vector
<
int64_t
>
samples
;
unordered_map
<
int64_t
,
int64_t
>
to_local_node
;
auto
*
colptr_data
=
colptr
.
data_ptr
<
int64_t
>
();
auto
*
row_data
=
row
.
data_ptr
<
int64_t
>
();
auto
*
input_node_data
=
input_node
.
data_ptr
<
int64_t
>
();
for
(
int64_t
i
=
0
;
i
<
input_node
.
numel
();
i
++
)
{
const
auto
&
v
=
input_node_data
[
i
];
samples
.
push_back
(
v
);
to_local_node
.
insert
({
v
,
i
});
}
vector
<
int64_t
>
rows
,
cols
,
edges
;
int64_t
begin
=
0
,
end
=
samples
.
size
();
for
(
int64_t
ell
=
0
;
ell
<
(
int64_t
)
num_neighbors
.
size
();
ell
++
)
{
const
auto
&
num_samples
=
num_neighbors
[
ell
];
for
(
int64_t
i
=
begin
;
i
<
end
;
i
++
)
{
const
auto
&
w
=
samples
[
i
];
const
auto
&
col_start
=
colptr_data
[
w
];
const
auto
&
col_end
=
colptr_data
[
w
+
1
];
const
auto
col_count
=
col_end
-
col_start
;
if
(
col_count
==
0
)
continue
;
if
(
replace
)
{
for
(
int64_t
j
=
0
;
j
<
num_samples
;
j
++
)
{
const
int64_t
offset
=
col_start
+
rand
()
%
col_count
;
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_node
.
insert
({
v
,
samples
.
size
()});
if
(
res
.
second
)
samples
.
push_back
(
v
);
if
(
directed
)
{
cols
.
push_back
(
i
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
}
}
}
else
if
(
num_samples
>=
col_count
)
{
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_node
.
insert
({
v
,
samples
.
size
()});
if
(
res
.
second
)
samples
.
push_back
(
v
);
if
(
directed
)
{
cols
.
push_back
(
i
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
}
}
}
else
{
unordered_set
<
int64_t
>
rnd_indices
;
for
(
int64_t
j
=
col_count
-
num_samples
;
j
<
col_count
;
j
++
)
{
int64_t
rnd
=
rand
()
%
j
;
if
(
!
rnd_indices
.
insert
(
rnd
).
second
)
{
rnd
=
j
;
rnd_indices
.
insert
(
j
);
}
const
int64_t
offset
=
col_start
+
rnd
;
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_node
.
insert
({
v
,
samples
.
size
()});
if
(
res
.
second
)
samples
.
push_back
(
v
);
if
(
directed
)
{
cols
.
push_back
(
i
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
}
}
}
}
begin
=
end
,
end
=
samples
.
size
();
}
if
(
!
directed
)
{
unordered_map
<
int64_t
,
int64_t
>::
iterator
iter
;
for
(
int64_t
i
=
0
;
i
<
(
int64_t
)
samples
.
size
();
i
++
)
{
const
auto
&
w
=
samples
[
i
];
const
auto
&
col_start
=
colptr_data
[
w
];
const
auto
&
col_end
=
colptr_data
[
w
+
1
];
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
const
auto
&
v
=
row_data
[
offset
];
iter
=
to_local_node
.
find
(
v
);
if
(
iter
!=
to_local_node
.
end
())
{
rows
.
push_back
(
iter
->
second
);
cols
.
push_back
(
i
);
edges
.
push_back
(
offset
);
}
}
}
}
return
make_tuple
(
from_vector
<
int64_t
>
(
samples
),
from_vector
<
int64_t
>
(
rows
),
from_vector
<
int64_t
>
(
cols
),
from_vector
<
int64_t
>
(
edges
));
}
template
<
bool
replace
,
bool
directed
>
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_sample
(
const
std
::
vector
<
node_t
>
&
node_types
,
const
std
::
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
std
::
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
)
{
// Create a mapping to convert single string relations to edge type triplets:
unordered_map
<
rel_t
,
edge_t
>
to_edge_type
;
for
(
const
auto
&
k
:
edge_types
)
to_edge_type
[
get
<
0
>
(
k
)
+
"__"
+
get
<
1
>
(
k
)
+
"__"
+
get
<
2
>
(
k
)]
=
k
;
// Initialize some data structures for the sampling process:
unordered_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
unordered_map
<
node_t
,
unordered_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
for
(
const
auto
&
k
:
node_types
)
{
samples_dict
[
k
];
to_local_node_dict
[
k
];
}
unordered_map
<
rel_t
,
vector
<
int64_t
>>
rows_dict
,
cols_dict
,
edges_dict
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
rows_dict
[
rel_type
];
cols_dict
[
rel_type
];
edges_dict
[
rel_type
];
}
// Add the input nodes to the output nodes:
for
(
const
auto
&
kv
:
input_node_dict
)
{
const
auto
&
node_type
=
kv
.
key
();
const
auto
&
input_node
=
kv
.
value
();
const
auto
*
input_node_data
=
input_node
.
data_ptr
<
int64_t
>
();
auto
&
samples
=
samples_dict
.
at
(
node_type
);
auto
&
to_local_node
=
to_local_node_dict
.
at
(
node_type
);
for
(
int64_t
i
=
0
;
i
<
input_node
.
numel
();
i
++
)
{
const
auto
&
v
=
input_node_data
[
i
];
samples
.
push_back
(
v
);
to_local_node
.
insert
({
v
,
i
});
}
}
unordered_map
<
node_t
,
pair
<
int64_t
,
int64_t
>>
slice_dict
;
for
(
const
auto
&
kv
:
samples_dict
)
slice_dict
[
kv
.
first
]
=
{
0
,
kv
.
second
.
size
()};
for
(
int64_t
ell
=
0
;
ell
<
num_hops
;
ell
++
)
{
for
(
const
auto
&
kv
:
num_neighbors_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
edge_type
=
to_edge_type
[
rel_type
];
const
auto
&
src_node_type
=
get
<
0
>
(
edge_type
);
const
auto
&
dst_node_type
=
get
<
2
>
(
edge_type
);
const
auto
&
num_samples
=
kv
.
value
()[
ell
];
const
auto
&
dst_samples
=
samples_dict
.
at
(
dst_node_type
);
auto
&
src_samples
=
samples_dict
.
at
(
src_node_type
);
auto
&
to_local_src_node
=
to_local_node_dict
.
at
(
src_node_type
);
const
auto
*
colptr_data
=
colptr_dict
.
at
(
rel_type
).
data_ptr
<
int64_t
>
();
const
auto
*
row_data
=
row_dict
.
at
(
rel_type
).
data_ptr
<
int64_t
>
();
auto
&
rows
=
rows_dict
.
at
(
rel_type
);
auto
&
cols
=
cols_dict
.
at
(
rel_type
);
auto
&
edges
=
edges_dict
.
at
(
rel_type
);
const
auto
&
begin
=
slice_dict
.
at
(
dst_node_type
).
first
;
const
auto
&
end
=
slice_dict
.
at
(
dst_node_type
).
second
;
for
(
int64_t
i
=
begin
;
i
<
end
;
i
++
)
{
const
auto
&
w
=
dst_samples
[
i
];
const
auto
&
col_start
=
colptr_data
[
w
];
const
auto
&
col_end
=
colptr_data
[
w
+
1
];
const
auto
col_count
=
col_end
-
col_start
;
if
(
col_count
==
0
)
continue
;
if
(
replace
)
{
for
(
int64_t
j
=
0
;
j
<
num_samples
;
j
++
)
{
const
int64_t
offset
=
col_start
+
rand
()
%
col_count
;
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
if
(
res
.
second
)
src_samples
.
push_back
(
v
);
if
(
directed
)
{
cols
.
push_back
(
i
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
}
}
}
else
if
(
num_samples
>=
col_count
)
{
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
if
(
res
.
second
)
src_samples
.
push_back
(
v
);
if
(
directed
)
{
cols
.
push_back
(
i
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
}
}
}
else
{
unordered_set
<
int64_t
>
rnd_indices
;
for
(
int64_t
j
=
col_count
-
num_samples
;
j
<
col_count
;
j
++
)
{
int64_t
rnd
=
rand
()
%
j
;
if
(
!
rnd_indices
.
insert
(
rnd
).
second
)
{
rnd
=
j
;
rnd_indices
.
insert
(
j
);
}
const
int64_t
offset
=
col_start
+
rnd
;
const
int64_t
&
v
=
row_data
[
offset
];
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
if
(
res
.
second
)
src_samples
.
push_back
(
v
);
if
(
directed
)
{
cols
.
push_back
(
i
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
}
}
}
}
}
for
(
const
auto
&
kv
:
samples_dict
)
{
slice_dict
[
kv
.
first
]
=
{
slice_dict
.
at
(
kv
.
first
).
second
,
kv
.
second
.
size
()};
}
}
if
(
!
directed
)
{
// Construct the subgraph among the sampled nodes:
unordered_map
<
int64_t
,
int64_t
>::
iterator
iter
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
edge_type
=
to_edge_type
[
rel_type
];
const
auto
&
src_node_type
=
get
<
0
>
(
edge_type
);
const
auto
&
dst_node_type
=
get
<
2
>
(
edge_type
);
const
auto
&
dst_samples
=
samples_dict
.
at
(
dst_node_type
);
auto
&
to_local_src_node
=
to_local_node_dict
.
at
(
src_node_type
);
const
auto
*
colptr_data
=
kv
.
value
().
data_ptr
<
int64_t
>
();
const
auto
*
row_data
=
row_dict
.
at
(
rel_type
).
data_ptr
<
int64_t
>
();
auto
&
rows
=
rows_dict
.
at
(
rel_type
);
auto
&
cols
=
cols_dict
.
at
(
rel_type
);
auto
&
edges
=
edges_dict
.
at
(
rel_type
);
for
(
int64_t
i
=
0
;
i
<
(
int64_t
)
dst_samples
.
size
();
i
++
)
{
const
auto
&
w
=
dst_samples
[
i
];
const
auto
&
col_start
=
colptr_data
[
w
];
const
auto
&
col_end
=
colptr_data
[
w
+
1
];
for
(
int64_t
offset
=
col_start
;
offset
<
col_end
;
offset
++
)
{
const
auto
&
v
=
row_data
[
offset
];
iter
=
to_local_src_node
.
find
(
v
);
if
(
iter
!=
to_local_src_node
.
end
())
{
rows
.
push_back
(
iter
->
second
);
cols
.
push_back
(
i
);
edges
.
push_back
(
offset
);
}
}
}
}
}
return
make_tuple
(
from_vector
<
node_t
,
int64_t
>
(
samples_dict
),
from_vector
<
rel_t
,
int64_t
>
(
rows_dict
),
from_vector
<
rel_t
,
int64_t
>
(
cols_dict
),
from_vector
<
rel_t
,
int64_t
>
(
edges_dict
));
}
}
// namespace
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
neighbor_sample_cpu
(
const
torch
::
Tensor
&
colptr
,
const
torch
::
Tensor
&
row
,
const
torch
::
Tensor
&
input_node
,
const
vector
<
int64_t
>
num_neighbors
,
const
bool
replace
,
const
bool
directed
)
{
if
(
replace
&&
directed
)
{
return
sample
<
true
,
true
>
(
colptr
,
row
,
input_node
,
num_neighbors
);
}
else
if
(
replace
&&
!
directed
)
{
return
sample
<
true
,
false
>
(
colptr
,
row
,
input_node
,
num_neighbors
);
}
else
if
(
!
replace
&&
directed
)
{
return
sample
<
false
,
true
>
(
colptr
,
row
,
input_node
,
num_neighbors
);
}
else
{
return
sample
<
false
,
false
>
(
colptr
,
row
,
input_node
,
num_neighbors
);
}
}
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_neighbor_sample_cpu
(
const
std
::
vector
<
node_t
>
&
node_types
,
const
std
::
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
std
::
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
,
const
bool
replace
,
const
bool
directed
)
{
if
(
replace
&&
directed
)
{
return
hetero_sample
<
true
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
else
if
(
replace
&&
!
directed
)
{
return
hetero_sample
<
true
,
false
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
else
if
(
!
replace
&&
directed
)
{
return
hetero_sample
<
false
,
true
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
else
{
return
hetero_sample
<
false
,
false
>
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
);
}
}
csrc/cpu/neighbor_sample_cpu.h
0 → 100644
View file @
014c4bae
#pragma once
#include <torch/extension.h>
typedef
std
::
string
node_t
;
typedef
std
::
tuple
<
std
::
string
,
std
::
string
,
std
::
string
>
edge_t
;
typedef
std
::
string
rel_t
;
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
neighbor_sample_cpu
(
const
torch
::
Tensor
&
colptr
,
const
torch
::
Tensor
&
row
,
const
torch
::
Tensor
&
input_node
,
const
std
::
vector
<
int64_t
>
num_neighbors
,
const
bool
replace
,
const
bool
directed
);
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_neighbor_sample_cpu
(
const
std
::
vector
<
node_t
>
&
node_types
,
const
std
::
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
std
::
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
,
const
bool
replace
,
const
bool
directed
);
csrc/cpu/utils.h
View file @
014c4bae
...
@@ -25,10 +25,23 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec,
...
@@ -25,10 +25,23 @@ inline torch::Tensor from_vector(const std::vector<scalar_t> &vec,
return
inplace
?
out
:
out
.
clone
();
return
inplace
?
out
:
out
.
clone
();
}
}
template
<
typename
key_t
,
typename
scalar_t
>
inline
c10
::
Dict
<
key_t
,
torch
::
Tensor
>
from_vector
(
const
std
::
unordered_map
<
key_t
,
std
::
vector
<
scalar_t
>>
&
vec_dict
,
bool
inplace
=
false
)
{
c10
::
Dict
<
key_t
,
torch
::
Tensor
>
out_dict
;
for
(
const
auto
&
kv
:
vec_dict
)
out_dict
.
insert
(
kv
.
first
,
from_vector
<
scalar_t
>
(
kv
.
second
,
inplace
));
return
out_dict
;
}
inline
torch
::
Tensor
inline
torch
::
Tensor
choice
(
int64_t
population
,
int64_t
num_samples
,
bool
replace
=
false
,
choice
(
int64_t
population
,
int64_t
num_samples
,
bool
replace
=
false
,
torch
::
optional
<
torch
::
Tensor
>
weight
=
torch
::
nullopt
)
{
torch
::
optional
<
torch
::
Tensor
>
weight
=
torch
::
nullopt
)
{
if
(
population
==
0
||
num_samples
==
0
)
return
torch
::
empty
({
0
},
at
::
kLong
);
if
(
!
replace
&&
num_samples
>=
population
)
if
(
!
replace
&&
num_samples
>=
population
)
return
torch
::
arange
(
population
,
at
::
kLong
);
return
torch
::
arange
(
population
,
at
::
kLong
);
...
@@ -47,18 +60,53 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
...
@@ -47,18 +60,53 @@ choice(int64_t population, int64_t num_samples, bool replace = false,
// Sample without replacement via Robert Floyd algorithm:
// Sample without replacement via Robert Floyd algorithm:
// https://www.nowherenearithaca.com/2013/05/
// https://www.nowherenearithaca.com/2013/05/
// robert-floyds-tiny-and-beautiful.html
// robert-floyds-tiny-and-beautiful.html
std
::
unordered_set
<
int64_t
>
values
;
for
(
int64_t
i
=
population
-
num_samples
;
i
<
population
;
i
++
)
{
if
(
!
values
.
insert
(
rand
()
%
i
).
second
)
values
.
insert
(
i
);
}
const
auto
out
=
torch
::
empty
(
num_samples
,
at
::
kLong
);
const
auto
out
=
torch
::
empty
(
num_samples
,
at
::
kLong
);
auto
*
out_data
=
out
.
data_ptr
<
int64_t
>
();
auto
*
out_data
=
out
.
data_ptr
<
int64_t
>
();
int64_t
i
=
0
;
std
::
unordered_set
<
int64_t
>
samples
;
for
(
const
auto
&
value
:
values
)
{
for
(
int64_t
i
=
population
-
num_samples
;
i
<
population
;
i
++
)
{
out_data
[
i
]
=
value
;
int64_t
sample
=
rand
()
%
i
;
i
++
;
if
(
!
samples
.
insert
(
sample
).
second
)
{
sample
=
i
;
samples
.
insert
(
sample
);
}
out_data
[
i
-
population
+
num_samples
]
=
sample
;
}
}
return
out
;
return
out
;
}
}
}
}
template
<
bool
replace
>
inline
void
uniform_choice
(
const
int64_t
population
,
const
int64_t
num_samples
,
const
int64_t
*
idx_data
,
std
::
vector
<
int64_t
>
*
samples
,
std
::
unordered_map
<
int64_t
,
int64_t
>
*
to_local_node
)
{
if
(
population
==
0
||
num_samples
==
0
)
return
;
if
(
replace
)
{
for
(
int64_t
i
=
0
;
i
<
num_samples
;
i
++
)
{
const
int64_t
&
v
=
idx_data
[
rand
()
%
population
];
if
(
to_local_node
->
insert
({
v
,
samples
->
size
()}).
second
)
samples
->
push_back
(
v
);
}
}
else
if
(
num_samples
>=
population
)
{
for
(
int64_t
i
=
0
;
i
<
population
;
i
++
)
{
const
int64_t
&
v
=
idx_data
[
i
];
if
(
to_local_node
->
insert
({
v
,
samples
->
size
()}).
second
)
samples
->
push_back
(
v
);
}
}
else
{
std
::
unordered_set
<
int64_t
>
indices
;
for
(
int64_t
i
=
population
-
num_samples
;
i
<
population
;
i
++
)
{
int64_t
j
=
rand
()
%
i
;
if
(
!
indices
.
insert
(
j
).
second
)
{
j
=
i
;
indices
.
insert
(
j
);
}
const
int64_t
&
v
=
idx_data
[
j
];
if
(
to_local_node
->
insert
({
v
,
samples
->
size
()}).
second
)
samples
->
push_back
(
v
);
}
}
}
csrc/hgt_sample.cpp
View file @
014c4bae
...
@@ -11,6 +11,7 @@ PyMODINIT_FUNC PyInit__hgt_sample_cpu(void) { return NULL; }
...
@@ -11,6 +11,7 @@ PyMODINIT_FUNC PyInit__hgt_sample_cpu(void) { return NULL; }
#endif
#endif
#endif
#endif
// Returns 'output_node_dict', 'row_dict', 'col_dict', 'output_edge_dict'
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hgt_sample
(
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
colptr_dict
,
hgt_sample
(
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
colptr_dict
,
...
...
csrc/neighbor_sample.cpp
0 → 100644
View file @
014c4bae
#include <Python.h>
#include <torch/script.h>
#include "cpu/neighbor_sample_cpu.h"
#ifdef _WIN32
#ifdef WITH_CUDA
PyMODINIT_FUNC
PyInit__neighbor_sample_cuda
(
void
)
{
return
NULL
;
}
#else
PyMODINIT_FUNC
PyInit__neighbor_sample_cpu
(
void
)
{
return
NULL
;
}
#endif
#endif
// Returns 'output_node', 'row', 'col', 'output_edge'
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
neighbor_sample
(
const
torch
::
Tensor
&
colptr
,
const
torch
::
Tensor
&
row
,
const
torch
::
Tensor
&
input_node
,
const
std
::
vector
<
int64_t
>
num_neighbors
,
const
bool
replace
,
const
bool
directed
)
{
return
neighbor_sample_cpu
(
colptr
,
row
,
input_node
,
num_neighbors
,
replace
,
directed
);
}
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hetero_neighbor_sample
(
const
std
::
vector
<
node_t
>
&
node_types
,
const
std
::
vector
<
edge_t
>
&
edge_types
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
rel_t
,
std
::
vector
<
int64_t
>>
&
num_neighbors_dict
,
const
int64_t
num_hops
,
const
bool
replace
,
const
bool
directed
)
{
return
hetero_neighbor_sample_cpu
(
node_types
,
edge_types
,
colptr_dict
,
row_dict
,
input_node_dict
,
num_neighbors_dict
,
num_hops
,
replace
,
directed
);
}
static
auto
registry
=
torch
::
RegisterOperators
()
.
op
(
"torch_sparse::neighbor_sample"
,
&
neighbor_sample
)
.
op
(
"torch_sparse::hetero_neighbor_sample"
,
&
hetero_neighbor_sample
);
torch_sparse/__init__.py
View file @
014c4bae
...
@@ -9,7 +9,8 @@ suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
...
@@ -9,7 +9,8 @@ suffix = 'cuda' if torch.cuda.is_available() else 'cpu'
for
library
in
[
for
library
in
[
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
,
'_metis'
,
'_rw'
,
'_version'
,
'_convert'
,
'_diag'
,
'_spmm'
,
'_spspmm'
,
'_metis'
,
'_rw'
,
'_saint'
,
'_sample'
,
'_ego_sample'
,
'_hgt_sample'
,
'_relabel'
'_saint'
,
'_sample'
,
'_ego_sample'
,
'_hgt_sample'
,
'_neighbor_sample'
,
'_relabel'
]:
]:
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
torch
.
ops
.
load_library
(
importlib
.
machinery
.
PathFinder
().
find_spec
(
f
'
{
library
}
_
{
suffix
}
'
,
[
osp
.
dirname
(
__file__
)]).
origin
)
f
'
{
library
}
_
{
suffix
}
'
,
[
osp
.
dirname
(
__file__
)]).
origin
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment