Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
6143af21
Unverified
Commit
6143af21
authored
Aug 11, 2022
by
Dong Wang
Committed by
GitHub
Aug 11, 2022
Browse files
use batch idx and node id as unique key for dedup in temporal sampling (#267)
Co-authored-by:
Dong Wang
<
d@dongs-mbp.lan
>
parent
916ba55b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
22 deletions
+73
-22
csrc/cpu/neighbor_sample_cpu.cpp
csrc/cpu/neighbor_sample_cpu.cpp
+73
-22
No files found.
csrc/cpu/neighbor_sample_cpu.cpp
View file @
6143af21
...
@@ -10,6 +10,8 @@ using namespace std;
...
@@ -10,6 +10,8 @@ using namespace std;
namespace
{
namespace
{
typedef
phmap
::
flat_hash_map
<
pair
<
int64_t
,
int64_t
>
,
int64_t
>
temporarl_edge_dict
;
template
<
bool
replace
,
bool
directed
>
template
<
bool
replace
,
bool
directed
>
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
tuple
<
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
,
torch
::
Tensor
>
sample
(
const
torch
::
Tensor
&
colptr
,
const
torch
::
Tensor
&
row
,
sample
(
const
torch
::
Tensor
&
colptr
,
const
torch
::
Tensor
&
row
,
...
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -146,11 +148,15 @@ hetero_sample(const vector<node_t> &node_types,
// Initialize some data structures for the sampling process:
// Initialize some data structures for the sampling process:
phmap
::
flat_hash_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
phmap
::
flat_hash_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
phmap
::
flat_hash_map
<
node_t
,
vector
<
pair
<
int64_t
,
int64_t
>>>
temp_samples_dict
;
phmap
::
flat_hash_map
<
node_t
,
phmap
::
flat_hash_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
phmap
::
flat_hash_map
<
node_t
,
phmap
::
flat_hash_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
phmap
::
flat_hash_map
<
node_t
,
temporarl_edge_dict
>
temp_to_local_node_dict
;
phmap
::
flat_hash_map
<
node_t
,
vector
<
int64_t
>>
root_time_dict
;
phmap
::
flat_hash_map
<
node_t
,
vector
<
int64_t
>>
root_time_dict
;
for
(
const
auto
&
node_type
:
node_types
)
{
for
(
const
auto
&
node_type
:
node_types
)
{
samples_dict
[
node_type
];
samples_dict
[
node_type
];
temp_samples_dict
[
node_type
];
to_local_node_dict
[
node_type
];
to_local_node_dict
[
node_type
];
temp_to_local_node_dict
[
node_type
];
root_time_dict
[
node_type
];
root_time_dict
[
node_type
];
}
}
...
@@ -175,20 +181,33 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -175,20 +181,33 @@ hetero_sample(const vector<node_t> &node_types,
}
}
auto
&
samples
=
samples_dict
.
at
(
node_type
);
auto
&
samples
=
samples_dict
.
at
(
node_type
);
auto
&
temp_samples
=
temp_samples_dict
.
at
(
node_type
);
auto
&
to_local_node
=
to_local_node_dict
.
at
(
node_type
);
auto
&
to_local_node
=
to_local_node_dict
.
at
(
node_type
);
auto
&
temp_to_local_node
=
temp_to_local_node_dict
.
at
(
node_type
);
auto
&
root_time
=
root_time_dict
.
at
(
node_type
);
auto
&
root_time
=
root_time_dict
.
at
(
node_type
);
for
(
int64_t
i
=
0
;
i
<
input_node
.
numel
();
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
input_node
.
numel
();
i
++
)
{
const
auto
&
v
=
input_node_data
[
i
];
const
auto
&
v
=
input_node_data
[
i
];
if
(
temporal
)
{
temp_samples
.
push_back
({
v
,
i
});
temp_to_local_node
.
insert
({{
v
,
i
},
i
});
}
else
{
samples
.
push_back
(
v
);
samples
.
push_back
(
v
);
to_local_node
.
insert
({
v
,
i
});
to_local_node
.
insert
({
v
,
i
});
}
if
(
temporal
)
if
(
temporal
)
root_time
.
push_back
(
node_time_data
[
v
]);
root_time
.
push_back
(
node_time_data
[
v
]);
}
}
}
}
phmap
::
flat_hash_map
<
node_t
,
pair
<
int64_t
,
int64_t
>>
slice_dict
;
phmap
::
flat_hash_map
<
node_t
,
pair
<
int64_t
,
int64_t
>>
slice_dict
;
if
(
temporal
)
{
for
(
const
auto
&
kv
:
temp_samples_dict
)
{
slice_dict
[
kv
.
first
]
=
{
0
,
kv
.
second
.
size
()};
}
}
else
{
for
(
const
auto
&
kv
:
samples_dict
)
for
(
const
auto
&
kv
:
samples_dict
)
slice_dict
[
kv
.
first
]
=
{
0
,
kv
.
second
.
size
()};
slice_dict
[
kv
.
first
]
=
{
0
,
kv
.
second
.
size
()};
}
vector
<
rel_t
>
all_rel_types
;
vector
<
rel_t
>
all_rel_types
;
for
(
const
auto
&
kv
:
num_neighbors_dict
)
{
for
(
const
auto
&
kv
:
num_neighbors_dict
)
{
...
@@ -203,8 +222,11 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -203,8 +222,11 @@ hetero_sample(const vector<node_t> &node_types,
const
auto
&
dst_node_type
=
get
<
2
>
(
edge_type
);
const
auto
&
dst_node_type
=
get
<
2
>
(
edge_type
);
const
auto
num_samples
=
num_neighbors_dict
.
at
(
rel_type
)[
ell
];
const
auto
num_samples
=
num_neighbors_dict
.
at
(
rel_type
)[
ell
];
const
auto
&
dst_samples
=
samples_dict
.
at
(
dst_node_type
);
const
auto
&
dst_samples
=
samples_dict
.
at
(
dst_node_type
);
const
auto
&
temp_dst_samples
=
temp_samples_dict
.
at
(
dst_node_type
);
auto
&
src_samples
=
samples_dict
.
at
(
src_node_type
);
auto
&
src_samples
=
samples_dict
.
at
(
src_node_type
);
auto
&
temp_src_samples
=
temp_samples_dict
.
at
(
src_node_type
);
auto
&
to_local_src_node
=
to_local_node_dict
.
at
(
src_node_type
);
auto
&
to_local_src_node
=
to_local_node_dict
.
at
(
src_node_type
);
auto
&
temp_to_local_src_node
=
temp_to_local_node_dict
.
at
(
src_node_type
);
const
torch
::
Tensor
&
colptr
=
colptr_dict
.
at
(
rel_type
);
const
torch
::
Tensor
&
colptr
=
colptr_dict
.
at
(
rel_type
);
const
auto
*
colptr_data
=
colptr
.
data_ptr
<
int64_t
>
();
const
auto
*
colptr_data
=
colptr
.
data_ptr
<
int64_t
>
();
...
@@ -223,7 +245,8 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -223,7 +245,8 @@ hetero_sample(const vector<node_t> &node_types,
const
auto
&
begin
=
slice_dict
.
at
(
dst_node_type
).
first
;
const
auto
&
begin
=
slice_dict
.
at
(
dst_node_type
).
first
;
const
auto
&
end
=
slice_dict
.
at
(
dst_node_type
).
second
;
const
auto
&
end
=
slice_dict
.
at
(
dst_node_type
).
second
;
for
(
int64_t
i
=
begin
;
i
<
end
;
i
++
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
i
++
)
{
const
auto
&
w
=
dst_samples
[
i
];
const
auto
&
w
=
temporal
?
temp_dst_samples
[
i
].
first
:
dst_samples
[
i
];
const
int64_t
root_w
=
temporal
?
temp_dst_samples
[
i
].
second
:
-
1
;
int64_t
dst_time
=
0
;
int64_t
dst_time
=
0
;
if
(
temporal
)
if
(
temporal
)
dst_time
=
dst_root_time
[
i
];
dst_time
=
dst_root_time
[
i
];
...
@@ -241,15 +264,18 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -241,15 +264,18 @@ hetero_sample(const vector<node_t> &node_types,
if
(
temporal
)
{
if
(
temporal
)
{
if
(
!
satisfy_time
(
node_time_dict
,
src_node_type
,
dst_time
,
v
))
if
(
!
satisfy_time
(
node_time_dict
,
src_node_type
,
dst_time
,
v
))
continue
;
continue
;
// force disjoint of computation tree
// force disjoint of computation tree
based on source batch idx.
// note that the sampling always needs to have directed=True
// note that the sampling always needs to have directed=True
// for temporal case
// for temporal case
// to_local_src_node is not used for temporal / directed case
// to_local_src_node is not used for temporal / directed case
const
int64_t
sample_idx
=
src_samples
.
size
();
const
auto
res
=
temp_to_local_src_node
.
insert
({{
v
,
root_w
},
(
int64_t
)
temp_src_samples
.
size
()});
src_samples
.
push_back
(
v
);
if
(
res
.
second
)
{
temp_src_samples
.
push_back
({
v
,
root_w
});
src_root_time
.
push_back
(
dst_time
);
src_root_time
.
push_back
(
dst_time
);
}
cols
.
push_back
(
i
);
cols
.
push_back
(
i
);
rows
.
push_back
(
sample_idx
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
edges
.
push_back
(
offset
);
}
else
{
}
else
{
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
...
@@ -272,14 +298,17 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -272,14 +298,17 @@ hetero_sample(const vector<node_t> &node_types,
// TODO Infinity loop if no neighbor satisfies time constraint:
// TODO Infinity loop if no neighbor satisfies time constraint:
if
(
!
satisfy_time
(
node_time_dict
,
src_node_type
,
dst_time
,
v
))
if
(
!
satisfy_time
(
node_time_dict
,
src_node_type
,
dst_time
,
v
))
continue
;
continue
;
// force disjoint of computation tree
// force disjoint of computation tree
based on source batch idx.
// note that the sampling always needs to have directed=True
// note that the sampling always needs to have directed=True
// for temporal case
// for temporal case
const
int64_t
sample_idx
=
src_samples
.
size
();
const
auto
res
=
temp_to_local_src_node
.
insert
({{
v
,
root_w
},
(
int64_t
)
temp_src_samples
.
size
()});
src_samples
.
push_back
(
v
);
if
(
res
.
second
)
{
temp_src_samples
.
push_back
({
v
,
root_w
});
src_root_time
.
push_back
(
dst_time
);
src_root_time
.
push_back
(
dst_time
);
}
cols
.
push_back
(
i
);
cols
.
push_back
(
i
);
rows
.
push_back
(
sample_idx
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
edges
.
push_back
(
offset
);
}
else
{
}
else
{
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
...
@@ -307,14 +336,17 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -307,14 +336,17 @@ hetero_sample(const vector<node_t> &node_types,
if
(
temporal
)
{
if
(
temporal
)
{
if
(
!
satisfy_time
(
node_time_dict
,
src_node_type
,
dst_time
,
v
))
if
(
!
satisfy_time
(
node_time_dict
,
src_node_type
,
dst_time
,
v
))
continue
;
continue
;
// force disjoint of computation tree
// force disjoint of computation tree
based on source batch idx.
// note that the sampling always needs to have directed=True
// note that the sampling always needs to have directed=True
// for temporal case
// for temporal case
const
int64_t
sample_idx
=
src_samples
.
size
();
const
auto
res
=
temp_to_local_src_node
.
insert
({{
v
,
root_w
},
(
int64_t
)
temp_src_samples
.
size
()});
src_samples
.
push_back
(
v
);
if
(
res
.
second
)
{
temp_src_samples
.
push_back
({
v
,
root_w
});
src_root_time
.
push_back
(
dst_time
);
src_root_time
.
push_back
(
dst_time
);
}
cols
.
push_back
(
i
);
cols
.
push_back
(
i
);
rows
.
push_back
(
sample_idx
);
rows
.
push_back
(
res
.
first
->
second
);
edges
.
push_back
(
offset
);
edges
.
push_back
(
offset
);
}
else
{
}
else
{
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
const
auto
res
=
to_local_src_node
.
insert
({
v
,
src_samples
.
size
()});
...
@@ -331,11 +363,18 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -331,11 +363,18 @@ hetero_sample(const vector<node_t> &node_types,
}
}
}
}
for
(
const
auto
&
kv
:
samples_dict
)
{
if
(
temporal
)
{
slice_dict
[
kv
.
first
]
=
{
slice_dict
.
at
(
kv
.
first
).
second
,
kv
.
second
.
size
()};
for
(
const
auto
&
kv
:
temp_samples_dict
)
{
slice_dict
[
kv
.
first
]
=
{
0
,
kv
.
second
.
size
()};
}
}
else
{
for
(
const
auto
&
kv
:
samples_dict
)
slice_dict
[
kv
.
first
]
=
{
0
,
kv
.
second
.
size
()};
}
}
}
}
// Temporal sample disable undirected
assert
(
!
(
temporal
&&
!
directed
));
if
(
!
directed
)
{
// Construct the subgraph among the sampled nodes:
if
(
!
directed
)
{
// Construct the subgraph among the sampled nodes:
phmap
::
flat_hash_map
<
int64_t
,
int64_t
>::
iterator
iter
;
phmap
::
flat_hash_map
<
int64_t
,
int64_t
>::
iterator
iter
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
for
(
const
auto
&
kv
:
colptr_dict
)
{
...
@@ -371,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -371,6 +410,18 @@ hetero_sample(const vector<node_t> &node_types,
}
}
}
}
// Construct samples dictionary from temporal sample dictionary.
if
(
temporal
)
{
for
(
const
auto
&
kv
:
temp_samples_dict
)
{
const
auto
&
node_type
=
kv
.
first
;
const
auto
&
samples
=
kv
.
second
;
samples_dict
[
node_type
].
reserve
(
samples
.
size
());
for
(
const
auto
&
v
:
samples
)
{
samples_dict
[
node_type
].
push_back
(
v
.
first
);
}
}
}
return
make_tuple
(
from_vector
<
node_t
,
int64_t
>
(
samples_dict
),
return
make_tuple
(
from_vector
<
node_t
,
int64_t
>
(
samples_dict
),
from_vector
<
rel_t
,
int64_t
>
(
rows_dict
),
from_vector
<
rel_t
,
int64_t
>
(
rows_dict
),
from_vector
<
rel_t
,
int64_t
>
(
cols_dict
),
from_vector
<
rel_t
,
int64_t
>
(
cols_dict
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment