Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-sparse
Commits
0172aeb3
Unverified
Commit
0172aeb3
authored
Apr 25, 2022
by
Matthias Fey
Committed by
GitHub
Apr 25, 2022
Browse files
Temporal neighbor sampling adjustments (part2) (#226)
* temporal neighbor sampling adjustments (part2) * fix
parent
caf7ddde
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
25 deletions
+24
-25
csrc/cpu/neighbor_sample_cpu.cpp
csrc/cpu/neighbor_sample_cpu.cpp
+24
-25
No files found.
csrc/cpu/neighbor_sample_cpu.cpp
View file @
0172aeb3
...
@@ -115,11 +115,11 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
...
@@ -115,11 +115,11 @@ sample(const torch::Tensor &colptr, const torch::Tensor &row,
}
}
inline
bool
satisfy_time
(
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
node_time_dict
,
inline
bool
satisfy_time
(
const
c10
::
Dict
<
node_t
,
torch
::
Tensor
>
&
node_time_dict
,
const
node_t
&
src_node_type
,
const
int64_t
&
dst_time
,
const
node_t
&
src_node_type
,
int64_t
dst_time
,
const
int64_t
&
src_node
)
{
int64_t
src_node
)
{
try
{
// Check whether src -> dst obeys the time constraint:
try
{
// Check whether src -> dst obeys the time constraint:
auto
src
_time
=
node_time_dict
.
at
(
src_node_type
)
.
data_ptr
<
int64_t
>
()
;
const
torch
::
Tensor
&
src_node
_time
=
node_time_dict
.
at
(
src_node_type
);
return
dst_time
<
src_time
[
src_node
]
;
return
src_node_time
.
data_ptr
<
int64_t
>
()[
src_node
]
<=
dst_time
;
}
catch
(
int
err
)
{
// If no time is given, fall back to normal sampling:
}
catch
(
int
err
)
{
// If no time is given, fall back to normal sampling:
return
true
;
return
true
;
}
}
...
@@ -143,14 +143,6 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -143,14 +143,6 @@ hetero_sample(const vector<node_t> &node_types,
to_edge_type
[
get
<
0
>
(
k
)
+
"__"
+
get
<
1
>
(
k
)
+
"__"
+
get
<
2
>
(
k
)]
=
k
;
to_edge_type
[
get
<
0
>
(
k
)
+
"__"
+
get
<
1
>
(
k
)
+
"__"
+
get
<
2
>
(
k
)]
=
k
;
// Initialize some data structures for the sampling process:
// Initialize some data structures for the sampling process:
unordered_map
<
rel_t
,
vector
<
int64_t
>>
rows_dict
,
cols_dict
,
edges_dict
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
rows_dict
[
rel_type
];
cols_dict
[
rel_type
];
edges_dict
[
rel_type
];
}
unordered_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
unordered_map
<
node_t
,
vector
<
int64_t
>>
samples_dict
;
unordered_map
<
node_t
,
unordered_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
unordered_map
<
node_t
,
unordered_map
<
int64_t
,
int64_t
>>
to_local_node_dict
;
unordered_map
<
node_t
,
vector
<
int64_t
>>
root_time_dict
;
unordered_map
<
node_t
,
vector
<
int64_t
>>
root_time_dict
;
...
@@ -160,14 +152,23 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -160,14 +152,23 @@ hetero_sample(const vector<node_t> &node_types,
root_time_dict
[
node_type
];
root_time_dict
[
node_type
];
}
}
unordered_map
<
rel_t
,
vector
<
int64_t
>>
rows_dict
,
cols_dict
,
edges_dict
;
for
(
const
auto
&
kv
:
colptr_dict
)
{
const
auto
&
rel_type
=
kv
.
key
();
rows_dict
[
rel_type
];
cols_dict
[
rel_type
];
edges_dict
[
rel_type
];
}
// Add the input nodes to the output nodes:
// Add the input nodes to the output nodes:
for
(
const
auto
&
kv
:
input_node_dict
)
{
for
(
const
auto
&
kv
:
input_node_dict
)
{
const
auto
&
node_type
=
kv
.
key
();
const
auto
&
node_type
=
kv
.
key
();
const
torch
::
Tensor
&
input_node
=
kv
.
value
();
const
torch
::
Tensor
&
input_node
=
kv
.
value
();
const
auto
*
input_node_data
=
input_node
.
data_ptr
<
int64_t
>
();
const
auto
*
input_node_data
=
input_node
.
data_ptr
<
int64_t
>
();
int64_t
*
node_time_data
;
int64_t
*
node_time_data
;
if
(
temporal
)
{
if
(
temporal
)
{
torch
::
Tensor
node_time
=
node_time_dict
.
at
(
node_type
);
const
torch
::
Tensor
&
node_time
=
node_time_dict
.
at
(
node_type
);
node_time_data
=
node_time
.
data_ptr
<
int64_t
>
();
node_time_data
=
node_time
.
data_ptr
<
int64_t
>
();
}
}
...
@@ -198,29 +199,27 @@ hetero_sample(const vector<node_t> &node_types,
...
@@ -198,29 +199,27 @@ hetero_sample(const vector<node_t> &node_types,
auto
&
src_samples
=
samples_dict
.
at
(
src_node_type
);
auto
&
src_samples
=
samples_dict
.
at
(
src_node_type
);
auto
&
to_local_src_node
=
to_local_node_dict
.
at
(
src_node_type
);
auto
&
to_local_src_node
=
to_local_node_dict
.
at
(
src_node_type
);
const
auto
*
colptr_data
=
const
torch
::
Tensor
&
colptr
=
colptr_dict
.
at
(
rel_type
);
((
torch
::
Tensor
)
colptr_dict
.
at
(
rel_type
))
.
data_ptr
<
int64_t
>
();
const
auto
*
colptr_data
=
colptr
.
data_ptr
<
int64_t
>
();
const
auto
*
row_data
=
const
torch
::
Tensor
&
row
=
row_dict
.
at
(
rel_type
);
((
torch
::
Tensor
)
row_dict
.
at
(
rel_type
))
.
data_ptr
<
int64_t
>
();
const
auto
*
row_data
=
row
.
data_ptr
<
int64_t
>
();
auto
&
rows
=
rows_dict
.
at
(
rel_type
);
auto
&
rows
=
rows_dict
.
at
(
rel_type
);
auto
&
cols
=
cols_dict
.
at
(
rel_type
);
auto
&
cols
=
cols_dict
.
at
(
rel_type
);
auto
&
edges
=
edges_dict
.
at
(
rel_type
);
auto
&
edges
=
edges_dict
.
at
(
rel_type
);
const
auto
&
begin
=
slice_dict
.
at
(
dst_node_type
).
first
;
const
auto
&
end
=
slice_dict
.
at
(
dst_node_type
).
second
;
if
(
begin
==
end
)
continue
;
// For temporal sampling, sampled nodes cannot have a timestamp greater
// For temporal sampling, sampled nodes cannot have a timestamp greater
// than the timestamp of the root nodes
.
// than the timestamp of the root nodes
:
const
auto
&
dst_root_time
=
root_time_dict
.
at
(
dst_node_type
);
const
auto
&
dst_root_time
=
root_time_dict
.
at
(
dst_node_type
);
auto
&
src_root_time
=
root_time_dict
.
at
(
src_node_type
);
auto
&
src_root_time
=
root_time_dict
.
at
(
src_node_type
);
const
auto
&
begin
=
slice_dict
.
at
(
dst_node_type
).
first
;
const
auto
&
end
=
slice_dict
.
at
(
dst_node_type
).
second
;
for
(
int64_t
i
=
begin
;
i
<
end
;
i
++
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
i
++
)
{
const
auto
&
w
=
dst_samples
[
i
];
const
auto
&
w
=
dst_samples
[
i
];
const
auto
&
dst_time
=
dst_root_time
[
i
];
int64_t
dst_time
=
0
;
if
(
temporal
)
dst_time
=
dst_root_time
[
i
];
const
auto
&
col_start
=
colptr_data
[
w
];
const
auto
&
col_start
=
colptr_data
[
w
];
const
auto
&
col_end
=
colptr_data
[
w
+
1
];
const
auto
&
col_end
=
colptr_data
[
w
+
1
];
const
auto
col_count
=
col_end
-
col_start
;
const
auto
col_count
=
col_end
-
col_start
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment