Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
4bff7c3f
Commit
4bff7c3f
authored
Jul 13, 2021
by
rusty1s
Browse files
update
parent
e4cac317
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
94 deletions
+75
-94
csrc/cpu/hgt_sample_cpu.cpp
csrc/cpu/hgt_sample_cpu.cpp
+70
-69
csrc/cpu/hgt_sample_cpu.h
csrc/cpu/hgt_sample_cpu.h
+2
-22
csrc/hgt_sample.cpp
csrc/hgt_sample.cpp
+3
-3
No files found.
csrc/cpu/hgt_sample_cpu.cpp
View file @
4bff7c3f
...
@@ -13,18 +13,23 @@ edge_t split(const rel_t &rel_type) {
...
@@ -13,18 +13,23 @@ edge_t split(const rel_t &rel_type) {
return
std
::
make_tuple
(
result
[
0
],
result
[
1
],
result
[
2
]);
return
std
::
make_tuple
(
result
[
0
],
result
[
1
],
result
[
2
]);
}
}
torch
::
Tensor
vec_to_tensor
(
const
std
::
vector
<
int64_t
>
&
v
)
{
return
torch
::
from_blob
((
int64_t
*
)
v
.
data
(),
{(
int64_t
)
v
.
size
()},
at
::
kLong
)
.
clone
();
}
template
<
typename
Container
>
void
update_budget
(
void
update_budget
(
std
::
unordered_map
<
node_t
,
std
::
unordered_map
<
int64_t
,
float
>>
*
budget_dict
,
std
::
unordered_map
<
node_t
,
std
::
unordered_map
<
int64_t
,
float
>>
*
budget_dict
,
const
node_t
&
node_type
,
//
const
node_t
&
node_type
,
//
const
std
::
vector
<
int64_t
>
&
sampled_nodes
,
const
Container
&
sampled_nodes
,
const
std
::
unordered_map
<
node_t
,
std
::
unordered_map
<
int64_t
,
int64_t
>>
const
std
::
unordered_map
<
node_t
,
std
::
unordered_map
<
int64_t
,
int64_t
>>
&
global_to_local_node_dict
,
&
global_to_local_node_dict
,
const
std
::
unordered_map
<
rel_t
,
edge_t
>
&
rel_to_edge_type
,
const
std
::
unordered_map
<
rel_t
,
edge_t
>
&
rel_to_edge_type
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
rowptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
colptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
col_dict
,
//
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row_dict
)
{
const
bool
remove
)
{
for
(
const
auto
&
kv
:
row
ptr_dict
)
{
for
(
const
auto
&
kv
:
col
ptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
edge_type
=
rel_to_edge_type
.
at
(
rel_type
);
const
auto
&
edge_type
=
rel_to_edge_type
.
at
(
rel_type
);
const
auto
&
src_node_type
=
std
::
get
<
0
>
(
edge_type
);
const
auto
&
src_node_type
=
std
::
get
<
0
>
(
edge_type
);
...
@@ -35,45 +40,49 @@ void update_budget(
...
@@ -35,45 +40,49 @@ void update_budget(
const
auto
&
global_to_local_node
=
const
auto
&
global_to_local_node
=
global_to_local_node_dict
.
at
(
src_node_type
);
global_to_local_node_dict
.
at
(
src_node_type
);
const
auto
*
row
ptr_data
=
kv
.
value
().
data_ptr
<
int64_t
>
();
const
auto
*
col
ptr_data
=
kv
.
value
().
data_ptr
<
int64_t
>
();
const
auto
*
col
_data
=
col
_dict
.
at
(
rel_type
).
data_ptr
<
int64_t
>
();
const
auto
*
row
_data
=
row
_dict
.
at
(
rel_type
).
data_ptr
<
int64_t
>
();
auto
&
budget
=
(
*
budget_dict
)[
src_node_type
];
auto
&
budget
=
(
*
budget_dict
)[
src_node_type
];
for
(
const
auto
&
v
:
sampled_nodes
)
{
for
(
const
auto
&
v
:
sampled_nodes
)
{
const
int64_t
row_start
=
rowptr_data
[
v
],
row_end
=
rowptr_data
[
v
+
1
];
const
int64_t
col_start
=
colptr_data
[
v
],
col_end
=
colptr_data
[
v
+
1
];
if
(
row_end
!=
row_start
)
{
if
(
col_end
!=
col_start
)
{
const
auto
inv_deg
=
1.
f
/
float
(
row_end
-
row_start
);
const
auto
inv_deg
=
1.
f
/
float
(
col_end
-
col_start
);
for
(
int64_t
j
=
row_start
;
j
<
row_end
;
j
++
)
{
for
(
int64_t
j
=
col_start
;
j
<
col_end
;
j
++
)
{
const
auto
w
=
col_data
[
j
];
const
auto
w
=
row_data
[
j
];
// Only add the neighbor in case we have not yet seen it before:
if
(
global_to_local_node
.
find
(
w
)
==
global_to_local_node
.
end
())
if
(
global_to_local_node
.
find
(
w
)
==
global_to_local_node
.
end
())
budget
[
col
_data
[
j
]]
+=
inv_deg
;
budget
[
row
_data
[
j
]]
+=
inv_deg
;
}
}
}
}
}
}
}
}
if
(
remove
)
{
auto
&
budget
=
(
*
budget_dict
)[
node_type
];
auto
&
budget
=
(
*
budget_dict
)[
node_type
];
for
(
const
auto
&
v
:
sampled_nodes
)
for
(
const
auto
&
v
:
sampled_nodes
)
budget
.
erase
(
v
);
budget
.
erase
(
v
);
}
}
}
std
::
unordered_set
<
int64_t
>
std
::
unordered_set
<
int64_t
>
sample_from
(
const
std
::
unordered_map
<
int64_t
,
float
>
&
budget
,
sample_from
(
const
std
::
unordered_map
<
int64_t
,
float
>
&
budget
,
const
int64_t
num_samples
)
{
const
int64_t
num_samples
)
{
std
::
unordered_set
<
int64_t
>
output
;
// Compute the squared L2 norm:
// Compute the squared L2 norm:
auto
norm
=
0.
f
;
auto
norm
=
0.
f
;
for
(
const
auto
&
kv
:
budget
)
for
(
const
auto
&
kv
:
budget
)
norm
+=
kv
.
second
*
kv
.
second
;
norm
+=
kv
.
second
*
kv
.
second
;
if
(
norm
==
0.
)
// No need to sample if there are no nodes in the budget:
return
output
;
// Generate `num_samples` sorted random values between `[0., norm)`:
// Generate `num_samples` sorted random values between `[0., norm)`:
std
::
vector
<
float
>
samples
(
num_samples
);
std
::
uniform_real_distribution
<
float
>
dist
(
0.
f
,
norm
);
std
::
default_random_engine
gen
{
std
::
random_device
{}()};
std
::
default_random_engine
gen
{
std
::
random_device
{}()};
std
::
generate
(
std
::
begin
(
samples
),
std
::
end
(
samples
),
std
::
uniform_real_distribution
<
float
>
dis
(
0.
f
,
norm
);
[
&
]
{
return
dist
(
gen
);
});
std
::
vector
<
float
>
samples
(
num_samples
);
for
(
int64_t
i
=
0
;
i
<
num_samples
;
i
++
)
samples
[
i
]
=
dis
(
gen
);
std
::
sort
(
samples
.
begin
(),
samples
.
end
());
std
::
sort
(
samples
.
begin
(),
samples
.
end
());
// Iterate through the budget to compute the cumulative probability
// Iterate through the budget to compute the cumulative probability
...
@@ -82,7 +91,6 @@ sample_from(const std::unordered_map<int64_t, float> &budget,
...
@@ -82,7 +91,6 @@ sample_from(const std::unordered_map<int64_t, float> &budget,
// The implementation assigns two iterators on budget and samples,
// The implementation assigns two iterators on budget and samples,
// respectively, and then computes the node samples in linear time by
// respectively, and then computes the node samples in linear time by
// alternatingly incrementing the two iterators based on their values.
// alternatingly incrementing the two iterators based on their values.
std
::
unordered_set
<
int64_t
>
output
;
output
.
reserve
(
num_samples
);
output
.
reserve
(
num_samples
);
auto
j
=
samples
.
begin
();
auto
j
=
samples
.
begin
();
...
@@ -106,15 +114,15 @@ sample_from(const std::unordered_map<int64_t, float> &budget,
...
@@ -106,15 +114,15 @@ sample_from(const std::unordered_map<int64_t, float> &budget,
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hgt_sample_cpu
(
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row
ptr_dict
,
hgt_sample_cpu
(
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
col
ptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
col
_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row
_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
node_t
,
std
::
vector
<
int64_t
>>
&
num_samples_dict
,
const
c10
::
Dict
<
node_t
,
std
::
vector
<
int64_t
>>
&
num_samples_dict
,
int64_t
num_hops
)
{
int64_t
num_hops
)
{
// Create mapping to convert single string relations to edge type triplets:
// Create mapping to convert single string relations to edge type triplets:
std
::
unordered_map
<
rel_t
,
edge_t
>
rel_to_edge_type
;
std
::
unordered_map
<
rel_t
,
edge_t
>
rel_to_edge_type
;
for
(
const
auto
&
kv
:
row
ptr_dict
)
{
for
(
const
auto
&
kv
:
col
ptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
rel_type
=
kv
.
key
();
rel_to_edge_type
[
rel_type
]
=
split
(
rel_type
);
rel_to_edge_type
[
rel_type
]
=
split
(
rel_type
);
}
}
...
@@ -131,7 +139,8 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
...
@@ -131,7 +139,8 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
budget_dict
[
node_type
];
budget_dict
[
node_type
];
}
}
// Add all input nodes of every node type to the sampled output set (line 1):
// Add all input nodes of every node type to the sampled output set, and
// compute initial budget (line 1-5):
for
(
const
auto
&
kv
:
input_node_dict
)
{
for
(
const
auto
&
kv
:
input_node_dict
)
{
const
auto
&
node_type
=
kv
.
key
();
const
auto
&
node_type
=
kv
.
key
();
const
auto
&
input_node
=
kv
.
value
();
const
auto
&
input_node
=
kv
.
value
();
...
@@ -140,19 +149,18 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
...
@@ -140,19 +149,18 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
auto
&
sampled_nodes
=
sampled_nodes_dict
.
at
(
node_type
);
auto
&
sampled_nodes
=
sampled_nodes_dict
.
at
(
node_type
);
auto
&
global_to_local_node
=
global_to_local_node_dict
.
at
(
node_type
);
auto
&
global_to_local_node
=
global_to_local_node_dict
.
at
(
node_type
);
// Add each origin node to the sampled output nodes:
// Add each origin node to the sampled output nodes
(line 1)
:
for
(
int64_t
i
=
0
;
i
<
input_node
.
numel
();
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
input_node
.
numel
();
i
++
)
{
const
auto
v
=
input_node_data
[
i
];
const
auto
v
=
input_node_data
[
i
];
sampled_nodes
.
push_back
(
v
);
sampled_nodes
.
push_back
(
v
);
global_to_local_node
[
v
]
=
i
;
global_to_local_node
[
v
]
=
i
;
}
}
}
// Update budget after
all
input nodes have been added to the sampled output
// Update budget after input nodes have been added to the sampled output
set
//
set
(line 2-5):
// (line 2-5):
for
(
const
auto
&
kv
:
sampled_nodes_dict
)
{
update_budget
<
std
::
vector
<
int64_t
>>
(
update_budget
(
&
budget_dict
,
kv
.
first
,
kv
.
second
,
global_to_local_node_dict
,
&
budget_dict
,
node_type
,
sampled_nodes
,
global_to_local_node_dict
,
rel_to_edge_type
,
row
ptr_dict
,
col
_dict
,
false
);
rel_to_edge_type
,
col
ptr_dict
,
row
_dict
);
}
}
// Sample nodes for each node type in each layer (line 6 - 18):
// Sample nodes for each node type in each layer (line 6 - 18):
...
@@ -166,21 +174,20 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
...
@@ -166,21 +174,20 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
// (line 9-11):
// (line 9-11):
const
auto
samples
=
sample_from
(
budget
,
num_samples
);
const
auto
samples
=
sample_from
(
budget
,
num_samples
);
if
(
samples
.
size
()
>
0
)
{
// Add sampled nodes to the sampled output set (line 13):
// Add sampled nodes to the sampled output set (line 13):
auto
&
sampled_nodes
=
sampled_nodes_dict
[
node_type
];
auto
&
sampled_nodes
=
sampled_nodes_dict
.
at
(
node_type
);
auto
&
global_to_local_node
=
global_to_local_node_dict
[
node_type
];
auto
&
global_to_local_node
=
global_to_local_node_dict
.
at
(
node_type
);
std
::
vector
<
int64_t
>
newly_sampled_nodes
;
newly_sampled_nodes
.
reserve
(
samples
.
size
());
for
(
const
auto
&
v
:
samples
)
{
for
(
const
auto
&
v
:
samples
)
{
sampled_nodes
.
push_back
(
v
);
sampled_nodes
.
push_back
(
v
);
newly_sampled_nodes
.
push_back
(
v
);
global_to_local_node
[
v
]
=
sampled_nodes
.
size
();
global_to_local_node
[
v
]
=
sampled_nodes
.
size
();
}
}
// Add neighbors of newly sampled nodes to the bucket (line 14-15):
// Add neighbors of newly sampled nodes to the bucket (line 14-15):
update_budget
(
&
budget_dict
,
node_type
,
newly_sampled_nodes
,
update_budget
<
std
::
unordered_set
<
int64_t
>>
(
global_to_local_node_dict
,
rel_to_edge_type
,
rowptr_dict
,
&
budget_dict
,
node_type
,
samples
,
global_to_local_node_dict
,
col_dict
,
true
);
rel_to_edge_type
,
colptr_dict
,
row_dict
);
}
}
}
}
}
...
@@ -188,14 +195,14 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
...
@@ -188,14 +195,14 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
output_row_dict
;
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
output_row_dict
;
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
output_col_dict
;
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
output_col_dict
;
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
output_edge_dict
;
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
output_edge_dict
;
for
(
const
auto
&
kv
:
row
ptr_dict
)
{
for
(
const
auto
&
kv
:
col
ptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
rel_type
=
kv
.
key
();
const
auto
&
edge_type
=
rel_to_edge_type
.
at
(
rel_type
);
const
auto
&
edge_type
=
rel_to_edge_type
.
at
(
rel_type
);
const
auto
&
src_node_type
=
std
::
get
<
0
>
(
edge_type
);
const
auto
&
src_node_type
=
std
::
get
<
0
>
(
edge_type
);
const
auto
&
dst_node_type
=
std
::
get
<
2
>
(
edge_type
);
const
auto
&
dst_node_type
=
std
::
get
<
2
>
(
edge_type
);
const
auto
*
row
ptr_data
=
kv
.
value
().
data_ptr
<
int64_t
>
();
const
auto
*
col
ptr_data
=
kv
.
value
().
data_ptr
<
int64_t
>
();
const
auto
*
col
_data
=
col
_dict
.
at
(
rel_type
).
data_ptr
<
int64_t
>
();
const
auto
*
row
_data
=
row
_dict
.
at
(
rel_type
).
data_ptr
<
int64_t
>
();
const
auto
&
sampled_dst_nodes
=
sampled_nodes_dict
[
dst_node_type
];
const
auto
&
sampled_dst_nodes
=
sampled_nodes_dict
[
dst_node_type
];
const
auto
&
global_to_local_src
=
global_to_local_node_dict
[
src_node_type
];
const
auto
&
global_to_local_src
=
global_to_local_node_dict
[
src_node_type
];
...
@@ -203,35 +210,29 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
...
@@ -203,35 +210,29 @@ hgt_sample_cpu(const c10::Dict<rel_t, torch::Tensor> &rowptr_dict,
std
::
vector
<
int64_t
>
rows
,
cols
,
edges
;
std
::
vector
<
int64_t
>
rows
,
cols
,
edges
;
for
(
int64_t
i
=
0
;
i
<
(
int64_t
)
sampled_dst_nodes
.
size
();
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
(
int64_t
)
sampled_dst_nodes
.
size
();
i
++
)
{
const
auto
v
=
sampled_dst_nodes
[
i
];
const
auto
v
=
sampled_dst_nodes
[
i
];
const
int64_t
row
_start
=
row
ptr_data
[
v
],
row
_end
=
row
ptr_data
[
v
+
1
];
const
int64_t
col
_start
=
col
ptr_data
[
v
],
col
_end
=
col
ptr_data
[
v
+
1
];
for
(
int64_t
j
=
row
_start
;
j
<
row
_end
;
j
++
)
{
for
(
int64_t
j
=
col
_start
;
j
<
col
_end
;
j
++
)
{
const
auto
w
=
col
_data
[
j
];
const
auto
w
=
row
_data
[
j
];
if
(
global_to_local_src
.
find
(
w
)
!=
global_to_local_src
.
end
())
{
if
(
global_to_local_src
.
find
(
w
)
!=
global_to_local_src
.
end
())
{
rows
.
push_back
(
i
);
rows
.
push_back
(
global_to_local_src
.
at
(
w
)
);
cols
.
push_back
(
global_to_local_src
.
at
(
w
)
);
cols
.
push_back
(
i
);
edges
.
push_back
(
j
);
edges
.
push_back
(
j
);
}
}
}
}
}
}
torch
::
Tensor
out
;
if
(
rows
.
size
()
>
0
)
{
out
=
torch
::
from_blob
((
int64_t
*
)
rows
.
data
(),
{(
int64_t
)
rows
.
size
()},
output_row_dict
.
insert
(
rel_type
,
vec_to_tensor
(
rows
));
at
::
kLong
);
output_col_dict
.
insert
(
rel_type
,
vec_to_tensor
(
cols
));
output_row_dict
.
insert
(
rel_type
,
out
.
clone
());
output_edge_dict
.
insert
(
rel_type
,
vec_to_tensor
(
edges
));
out
=
torch
::
from_blob
((
int64_t
*
)
cols
.
data
(),
{(
int64_t
)
cols
.
size
()},
}
at
::
kLong
);
output_col_dict
.
insert
(
rel_type
,
out
.
clone
());
out
=
torch
::
from_blob
((
int64_t
*
)
edges
.
data
(),
{(
int64_t
)
edges
.
size
()},
at
::
kLong
);
output_edge_dict
.
insert
(
rel_type
,
out
.
clone
());
}
}
// Generate tensor-valued output node dict (line 20):
// Generate tensor-valued output node dict (line 20):
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
output_node_dict
;
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
output_node_dict
;
for
(
const
auto
&
kv
:
sampled_nodes_dict
)
{
for
(
const
auto
&
kv
:
sampled_nodes_dict
)
{
const
auto
out
=
torch
::
from_blob
((
int64_t
*
)
kv
.
second
.
data
(),
if
(
kv
.
second
.
size
()
>
0
)
{(
int64_t
)
kv
.
second
.
size
()},
at
::
kLong
);
output_node_dict
.
insert
(
kv
.
first
,
vec_to_tensor
(
kv
.
second
));
output_node_dict
.
insert
(
kv
.
first
,
out
.
clone
());
}
}
return
std
::
make_tuple
(
output_node_dict
,
output_row_dict
,
output_col_dict
,
return
std
::
make_tuple
(
output_node_dict
,
output_row_dict
,
output_col_dict
,
...
...
csrc/cpu/hgt_sample_cpu.h
View file @
4bff7c3f
// #pragma once
// #include <torch/extension.h>
// // Node type is a string and the edge type is a triplet of string
// representing
// // (source_node_type, relation_type, dest_node_type).
// typedef std::string node_t;
// typedef std::tuple<std::string, std::string, std::string> edge_t;
// // As of PyTorch 1.9.0, c10::Dict does not support tuples or complex data
// type as key type. We work around this
// // by representing edge types using a single int64_t and a c10::Dict that
// maps the int64_t index to edge_t. void hg_sample_cpu( const
// c10::Dict<int64_t, torch::Tensor> &rowptr_store, const c10::Dict<int64_t,
// torch::Tensor> &col_store, const c10::Dict<node_t, torch::Tensor>
// &origin_nodes_store, const c10::Dict<int64_t, edge_t>
// &edge_type_idx_to_name, int n, int num_layers
// );
//
#pragma once
#pragma once
#include <torch/extension.h>
#include <torch/extension.h>
...
@@ -30,8 +10,8 @@ const std::string delim = "__";
...
@@ -30,8 +10,8 @@ const std::string delim = "__";
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hgt_sample_cpu
(
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row
ptr_dict
,
hgt_sample_cpu
(
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
col
ptr_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
col
_dict
,
const
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
&
row
_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
node_t
,
std
::
vector
<
int64_t
>>
&
num_samples_dict
,
const
c10
::
Dict
<
node_t
,
std
::
vector
<
int64_t
>>
&
num_samples_dict
,
int64_t
num_hops
);
int64_t
num_hops
);
csrc/hgt_sample.cpp
View file @
4bff7c3f
...
@@ -13,13 +13,13 @@ PyMODINIT_FUNC PyInit__hgt_sample_cpu(void) { return NULL; }
...
@@ -13,13 +13,13 @@ PyMODINIT_FUNC PyInit__hgt_sample_cpu(void) { return NULL; }
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
std
::
tuple
<
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>
,
c10
::
Dict
<
rel_t
,
torch
::
Tensor
>>
hgt_sample
(
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
row
ptr_dict
,
hgt_sample
(
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
col
ptr_dict
,
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
col
_dict
,
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
row
_dict
,
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
std
::
string
,
torch
::
Tensor
>
&
input_node_dict
,
const
c10
::
Dict
<
std
::
string
,
std
::
vector
<
int64_t
>>
&
num_samples_dict
,
const
c10
::
Dict
<
std
::
string
,
std
::
vector
<
int64_t
>>
&
num_samples_dict
,
const
int64_t
num_hops
)
{
const
int64_t
num_hops
)
{
return
hgt_sample_cpu
(
row
ptr_dict
,
col
_dict
,
input_node_dict
,
return
hgt_sample_cpu
(
col
ptr_dict
,
row
_dict
,
input_node_dict
,
num_samples_dict
,
num_hops
);
num_samples_dict
,
num_hops
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment