Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
658b2086
Unverified
Commit
658b2086
authored
Apr 28, 2024
by
Ramon Zhou
Committed by
GitHub
Apr 28, 2024
Browse files
[GraphBolt] Optimize hetero sampling on CPU (#7360)
parent
9090a879
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
448 additions
and
64 deletions
+448
-64
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+30
-12
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+413
-50
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
...n/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
+5
-2
No files found.
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
658b2086
...
@@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -415,6 +415,13 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private:
private:
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighborsImpl
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
SampleNeighborsImpl
(
const
torch
::
Tensor
&
seeds
,
torch
::
optional
<
std
::
vector
<
int64_t
>>&
seed_offsets
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
TemporalSampleNeighborsImpl
(
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
PickFn
pick_fn
)
const
;
...
@@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
...
@@ -498,13 +505,14 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @param offset The starting edge ID for the connected neighbors of the given
* @param offset The starting edge ID for the connected neighbors of the given
* node.
* node.
* @param num_neighbors The number of neighbors of this node.
* @param num_neighbors The number of neighbors of this node.
*
*
@param num_picked_ptr The pointer of the tensor which stores the pick
*
@return The pick number of the given node
.
*
numbers
.
*/
*/
int64_t
NumPick
(
template
<
typename
PickedNumType
>
void
NumPick
(
int64_t
fanout
,
bool
replace
,
int64_t
fanout
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
int64_t
num_neighbors
);
int64_t
num_neighbors
,
PickedNumType
*
num_picked_ptr
);
int64_t
TemporalNumPick
(
int64_t
TemporalNumPick
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indics
,
int64_t
fanout
,
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indics
,
int64_t
fanout
,
...
@@ -513,11 +521,13 @@ int64_t TemporalNumPick(
...
@@ -513,11 +521,13 @@ int64_t TemporalNumPick(
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
edge_timestamp
,
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
);
int64_t
offset
,
int64_t
num_neighbors
);
int64_t
NumPickByEtype
(
template
<
typename
PickedNumType
>
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
void
NumPickByEtype
(
bool
with_seed_offsets
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
int64_t
num_neighbors
);
int64_t
num_neighbors
,
PickedNumType
*
num_picked_ptr
,
int64_t
seed_index
,
const
std
::
vector
<
int64_t
>&
etype_id_to_num_picked_offset
);
int64_t
TemporalNumPickByEtype
(
int64_t
TemporalNumPickByEtype
(
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
torch
::
Tensor
seed_timestamp
,
torch
::
Tensor
csc_indices
,
...
@@ -610,16 +620,24 @@ int64_t TemporalPick(
...
@@ -610,16 +620,24 @@ int64_t TemporalPick(
* probabilities associated with each neighboring edge of a node in the original
* probabilities associated with each neighboring edge of a node in the original
* graph. It must be a 1D floating-point tensor with the number of elements
* graph. It must be a 1D floating-point tensor with the number of elements
* equal to the number of edges in the graph.
* equal to the number of edges in the graph.
* @param picked_data_ptr The
destination address
where the picked neighbors
* @param picked_data_ptr The
pointer of the tensor
where the picked neighbors
* should be put. Enough memory space should be allocated in advance.
* should be put. Enough memory space should be allocated in advance.
* @param seed_offset The offset(index) of the seed among the group of seeds
* which share the same node type.
* @param subgraph_indptr_ptr The pointer of the tensor which stores the indptr
* of the sampled subgraph.
* @param etype_id_to_num_picked_offset A vector storing the mappings from each
* etype_id to the offset of its pick numbers in the tensor.
*/
*/
template
<
SamplerType
S
,
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
int64_t
PickByEtype
(
int64_t
PickByEtype
(
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
with_seed_offsets
,
int64_t
offset
,
int64_t
num_neighbors
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
PickedType
*
picked_data_ptr
);
PickedType
*
picked_data_ptr
,
int64_t
seed_offset
,
PickedType
*
subgraph_indptr_ptr
,
const
std
::
vector
<
int64_t
>&
etype_id_to_num_picked_offset
);
template
<
typename
PickedType
>
template
<
typename
PickedType
>
int64_t
TemporalPickByEtype
(
int64_t
TemporalPickByEtype
(
...
...
graphbolt/src/fused_csc_sampling_graph.cc
View file @
658b2086
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include <type_traits>
#include <type_traits>
#include <vector>
#include <vector>
#include "./expand_indptr.h"
#include "./macro.h"
#include "./macro.h"
#include "./random.h"
#include "./random.h"
#include "./shared_memory_helper.h"
#include "./shared_memory_helper.h"
...
@@ -355,17 +356,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
...
@@ -355,17 +356,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
auto
GetNumPickFn
(
auto
GetNumPickFn
(
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
)
{
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
bool
with_seed_offsets
)
{
// If fanouts.size() > 1, returns the total number of all edge types of the
// If fanouts.size() > 1, returns the total number of all edge types of the
// given node.
// given node.
return
[
&
fanouts
,
replace
,
&
probs_or_mask
,
&
type_per_edge
](
return
[
&
fanouts
,
replace
,
&
probs_or_mask
,
&
type_per_edge
,
with_seed_offsets
](
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
)
{
int64_t
offset
,
int64_t
num_neighbors
,
auto
num_picked_ptr
,
int64_t
seed_index
,
const
std
::
vector
<
int64_t
>&
etype_id_to_num_picked_offset
)
{
if
(
fanouts
.
size
()
>
1
)
{
if
(
fanouts
.
size
()
>
1
)
{
return
NumPickByEtype
(
NumPickByEtype
(
fanouts
,
replace
,
type_per_edge
.
value
(),
probs_or_mask
,
offset
,
with_seed_offsets
,
fanouts
,
replace
,
type_per_edge
.
value
(),
num_neighbors
);
probs_or_mask
,
offset
,
num_neighbors
,
num_picked_ptr
,
seed_index
,
etype_id_to_num_picked_offset
);
}
else
{
}
else
{
return
NumPick
(
fanouts
[
0
],
replace
,
probs_or_mask
,
offset
,
num_neighbors
);
NumPick
(
fanouts
[
0
],
replace
,
probs_or_mask
,
offset
,
num_neighbors
,
num_picked_ptr
+
seed_index
);
}
}
};
};
}
}
...
@@ -423,21 +430,25 @@ auto GetPickFn(
...
@@ -423,21 +430,25 @@ auto GetPickFn(
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
)
{
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
bool
with_seed_offsets
,
return
[
&
fanouts
,
replace
,
&
options
,
&
type_per_edge
,
&
probs_or_mask
,
args
](
SamplerArgs
<
S
>
args
)
{
int64_t
seed_offset
,
int64_t
offset
,
int64_t
num_neighbors
,
return
[
&
fanouts
,
replace
,
&
options
,
&
type_per_edge
,
&
probs_or_mask
,
args
,
auto
picked_data_ptr
)
{
with_seed_offsets
](
int64_t
offset
,
int64_t
num_neighbors
,
auto
picked_data_ptr
,
int64_t
seed_offset
,
auto
subgraph_indptr_ptr
,
const
std
::
vector
<
int64_t
>&
etype_id_to_num_picked_offset
)
{
// If fanouts.size() > 1, perform sampling for each edge type of each
// If fanouts.size() > 1, perform sampling for each edge type of each
// node; otherwise just sample once for each node with no regard of edge
// node; otherwise just sample once for each node with no regard of edge
// types.
// types.
if
(
fanouts
.
size
()
>
1
)
{
if
(
fanouts
.
size
()
>
1
)
{
return
PickByEtype
(
return
PickByEtype
(
offset
,
num_neighbors
,
fanouts
,
replace
,
options
,
with_seed_offsets
,
offset
,
num_neighbors
,
fanouts
,
replace
,
options
,
type_per_edge
.
value
(),
probs_or_mask
,
args
,
picked_data_ptr
);
type_per_edge
.
value
(),
probs_or_mask
,
args
,
picked_data_ptr
,
seed_offset
,
subgraph_indptr_ptr
,
etype_id_to_num_picked_offset
);
}
else
{
}
else
{
int64_t
num_sampled
=
Pick
(
int64_t
num_sampled
=
Pick
(
offset
,
num_neighbors
,
fanouts
[
0
],
replace
,
options
,
probs_or_mask
,
offset
,
num_neighbors
,
fanouts
[
0
],
replace
,
options
,
probs_or_mask
,
args
,
picked_data_ptr
);
args
,
picked_data_ptr
+
subgraph_indptr_ptr
[
seed_offset
]
);
if
(
type_per_edge
)
{
if
(
type_per_edge
)
{
std
::
sort
(
picked_data_ptr
,
picked_data_ptr
+
num_sampled
);
std
::
sort
(
picked_data_ptr
,
picked_data_ptr
+
num_sampled
);
}
}
...
@@ -484,6 +495,304 @@ auto GetTemporalPickFn(
...
@@ -484,6 +495,304 @@ auto GetTemporalPickFn(
template
<
typename
NumPickFn
,
typename
PickFn
>
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
const
torch
::
Tensor
&
seeds
,
torch
::
optional
<
std
::
vector
<
int64_t
>>&
seed_offsets
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
{
const
int64_t
num_seeds
=
seeds
.
size
(
0
);
const
auto
indptr_options
=
indptr_
.
options
();
// Calculate GrainSize for parallel_for.
// Set the default grain size to 64.
const
int64_t
grain_size
=
64
;
torch
::
Tensor
picked_eids
;
torch
::
Tensor
subgraph_indptr
;
torch
::
Tensor
subgraph_indices
;
torch
::
optional
<
torch
::
Tensor
>
subgraph_type_per_edge
=
torch
::
nullopt
;
torch
::
optional
<
torch
::
Tensor
>
edge_offsets
=
torch
::
nullopt
;
bool
with_seed_offsets
=
seed_offsets
.
has_value
();
bool
hetero_with_seed_offsets
=
with_seed_offsets
&&
fanouts
.
size
()
>
1
;
// Get the number of edge types. If it's homo or if the size of fanouts is 1
// (hetero graph but sampled as a homo graph), set num_etypes as 1.
// In temporal sampling, this will not be used for now since the logic hasn't
// been adopted for temporal sampling.
const
int64_t
num_etypes
=
(
edge_type_to_id_
.
has_value
()
&&
hetero_with_seed_offsets
)
?
edge_type_to_id_
->
size
()
:
1
;
std
::
vector
<
int64_t
>
etype_id_to_src_ntype_id
(
num_etypes
);
std
::
vector
<
int64_t
>
etype_id_to_dst_ntype_id
(
num_etypes
);
torch
::
optional
<
torch
::
Tensor
>
subgraph_indptr_substract
=
torch
::
nullopt
;
// The pick numbers are stored in a single tensor by the order of etype. Each
// etype corresponds to a group of seeds whose ntype are the same as the
// dst_type. `etype_id_to_num_picked_offset` indicates the beginning offset
// where each etype's corresponding seeds' pick numbers are stored in the pick
// number tensor.
std
::
vector
<
int64_t
>
etype_id_to_num_picked_offset
(
num_etypes
+
1
);
if
(
hetero_with_seed_offsets
)
{
for
(
auto
&
etype_and_id
:
edge_type_to_id_
.
value
())
{
auto
etype
=
etype_and_id
.
key
();
auto
id
=
etype_and_id
.
value
();
auto
[
src_type
,
dst_type
]
=
utils
::
parse_src_dst_ntype_from_etype
(
etype
);
auto
dst_ntype_id
=
node_type_to_id_
->
at
(
dst_type
);
etype_id_to_src_ntype_id
[
id
]
=
node_type_to_id_
->
at
(
src_type
);
etype_id_to_dst_ntype_id
[
id
]
=
dst_ntype_id
;
etype_id_to_num_picked_offset
[
id
+
1
]
=
seed_offsets
->
at
(
dst_ntype_id
+
1
)
-
seed_offsets
->
at
(
dst_ntype_id
)
+
1
;
}
std
::
partial_sum
(
etype_id_to_num_picked_offset
.
begin
(),
etype_id_to_num_picked_offset
.
end
(),
etype_id_to_num_picked_offset
.
begin
());
}
else
{
etype_id_to_dst_ntype_id
[
0
]
=
0
;
etype_id_to_num_picked_offset
[
1
]
=
num_seeds
+
1
;
}
// `num_rows` indicates the length of `num_picked_neighbors_per_node`, which
// is used for storing pick numbers. In non-temporal hetero sampling, it
// equals to sum_{etype} #seeds with ntype=dst_type(etype). In homo sampling,
// it equals to `num_seeds`.
const
int64_t
num_rows
=
etype_id_to_num_picked_offset
[
num_etypes
];
torch
::
Tensor
num_picked_neighbors_per_node
=
torch
::
empty
({
num_rows
},
indptr_options
);
AT_DISPATCH_INDEX_TYPES
(
indptr_
.
scalar_type
(),
"SampleNeighborsImplWrappedWithIndptr"
,
([
&
]
{
using
indptr_t
=
index_t
;
AT_DISPATCH_INDEX_TYPES
(
seeds
.
scalar_type
(),
"SampleNeighborsImplWrappedWithSeeds"
,
([
&
]
{
using
seeds_t
=
index_t
;
const
auto
indptr_data
=
indptr_
.
data_ptr
<
indptr_t
>
();
const
auto
num_picked_neighbors_data_ptr
=
num_picked_neighbors_per_node
.
data_ptr
<
indptr_t
>
();
num_picked_neighbors_data_ptr
[
0
]
=
0
;
const
auto
seeds_data_ptr
=
seeds
.
data_ptr
<
seeds_t
>
();
// Initialize the empty spots in `num_picked_neighbors_per_node`.
if
(
hetero_with_seed_offsets
)
{
for
(
auto
i
=
0
;
i
<
num_etypes
;
++
i
)
{
num_picked_neighbors_data_ptr
[
etype_id_to_num_picked_offset
[
i
]]
=
0
;
}
}
// Step 1. Calculate pick number of each node.
torch
::
parallel_for
(
0
,
num_seeds
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
seeds_data_ptr
[
i
];
TORCH_CHECK
(
nid
>=
0
&&
nid
<
NumNodes
(),
"The seed nodes' IDs should fall within the range of "
"the graph's node IDs."
);
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
const
auto
seed_type_id
=
(
hetero_with_seed_offsets
)
?
std
::
upper_bound
(
seed_offsets
->
begin
(),
seed_offsets
->
end
(),
i
)
-
seed_offsets
->
begin
()
-
1
:
0
;
// `seed_index` indicates the index of the current
// seed within the group of seeds which have the same
// node type.
const
auto
seed_index
=
(
hetero_with_seed_offsets
)
?
i
-
seed_offsets
->
at
(
seed_type_id
)
:
i
;
num_pick_fn
(
offset
,
num_neighbors
,
num_picked_neighbors_data_ptr
+
1
,
seed_index
,
etype_id_to_num_picked_offset
);
}
});
if
(
hetero_with_seed_offsets
)
{
torch
::
Tensor
num_picked_offset_tensor
=
torch
::
zeros
({
num_etypes
+
1
},
indptr_options
);
torch
::
Tensor
substract_offset
=
torch
::
zeros
({
num_etypes
},
indptr_options
);
const
auto
substract_offset_data_ptr
=
substract_offset
.
data_ptr
<
indptr_t
>
();
const
auto
num_picked_offset_data_ptr
=
num_picked_offset_tensor
.
data_ptr
<
indptr_t
>
();
for
(
auto
i
=
0
;
i
<
num_etypes
;
++
i
)
{
num_picked_offset_data_ptr
[
i
+
1
]
=
etype_id_to_num_picked_offset
[
i
+
1
];
// Collect the total pick number for each edge type.
if
(
i
+
1
<
num_etypes
)
substract_offset_data_ptr
[
i
+
1
]
=
num_picked_neighbors_data_ptr
[
etype_id_to_num_picked_offset
[
i
]];
num_picked_neighbors_data_ptr
[
etype_id_to_num_picked_offset
[
i
]]
=
0
;
}
substract_offset
=
substract_offset
.
cumsum
(
0
,
indptr_
.
scalar_type
());
subgraph_indptr_substract
=
ops
::
ExpandIndptr
(
num_picked_offset_tensor
,
indptr_
.
scalar_type
(),
substract_offset
);
}
// Step 2. Calculate prefix sum to get total length and offsets of
// each node. It's also the indptr of the generated subgraph.
subgraph_indptr
=
num_picked_neighbors_per_node
.
cumsum
(
0
,
indptr_
.
scalar_type
());
auto
subgraph_indptr_data_ptr
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
();
// When doing non-temporal hetero sampling, we generate an
// edge_offsets tensor.
if
(
hetero_with_seed_offsets
)
{
edge_offsets
=
torch
::
empty
({
num_etypes
+
1
},
indptr_options
);
AT_DISPATCH_INTEGRAL_TYPES
(
edge_offsets
.
value
().
scalar_type
(),
"CalculateEdgeOffsets"
,
([
&
]
{
auto
edge_offsets_data_ptr
=
edge_offsets
.
value
().
data_ptr
<
scalar_t
>
();
edge_offsets_data_ptr
[
0
]
=
0
;
for
(
auto
i
=
0
;
i
<
num_etypes
;
++
i
)
{
edge_offsets_data_ptr
[
i
+
1
]
=
subgraph_indptr_data_ptr
[
etype_id_to_num_picked_offset
[
i
+
1
]
-
1
];
}
}));
}
// Step 3. Allocate the tensor for picked neighbors.
const
auto
total_length
=
subgraph_indptr
.
data_ptr
<
indptr_t
>
()[
num_rows
-
1
];
picked_eids
=
torch
::
empty
({
total_length
},
indptr_options
);
subgraph_indices
=
torch
::
empty
({
total_length
},
indices_
.
options
());
if
(
!
hetero_with_seed_offsets
&&
type_per_edge_
.
has_value
())
{
subgraph_type_per_edge
=
torch
::
empty
(
{
total_length
},
type_per_edge_
.
value
().
options
());
}
auto
picked_eids_data_ptr
=
picked_eids
.
data_ptr
<
indptr_t
>
();
torch
::
parallel_for
(
0
,
num_seeds
,
grain_size
,
[
&
](
int64_t
begin
,
int64_t
end
)
{
for
(
int64_t
i
=
begin
;
i
<
end
;
++
i
)
{
const
auto
nid
=
seeds_data_ptr
[
i
];
const
auto
offset
=
indptr_data
[
nid
];
const
auto
num_neighbors
=
indptr_data
[
nid
+
1
]
-
offset
;
auto
picked_number
=
0
;
const
auto
seed_type_id
=
(
hetero_with_seed_offsets
)
?
std
::
upper_bound
(
seed_offsets
->
begin
(),
seed_offsets
->
end
(),
i
)
-
seed_offsets
->
begin
()
-
1
:
0
;
const
auto
seed_index
=
(
hetero_with_seed_offsets
)
?
i
-
seed_offsets
->
at
(
seed_type_id
)
:
i
;
// Step 4. Pick neighbors for each node.
picked_number
=
pick_fn
(
offset
,
num_neighbors
,
picked_eids_data_ptr
,
seed_index
,
subgraph_indptr_data_ptr
,
etype_id_to_num_picked_offset
);
if
(
!
hetero_with_seed_offsets
)
{
TORCH_CHECK
(
num_picked_neighbors_data_ptr
[
i
+
1
]
==
picked_number
,
"Actual picked count doesn't match the calculated "
"pick number."
);
}
// Step 5. Calculate other attributes and return the
// subgraph.
if
(
picked_number
>
0
)
{
AT_DISPATCH_INDEX_TYPES
(
subgraph_indices
.
scalar_type
(),
"IndexSelectSubgraphIndices"
,
([
&
]
{
auto
subgraph_indices_data_ptr
=
subgraph_indices
.
data_ptr
<
index_t
>
();
auto
indices_data_ptr
=
indices_
.
data_ptr
<
index_t
>
();
for
(
auto
i
=
0
;
i
<
num_etypes
;
++
i
)
{
if
(
etype_id_to_dst_ntype_id
[
i
]
!=
seed_type_id
)
continue
;
const
auto
indptr_offset
=
with_seed_offsets
?
etype_id_to_num_picked_offset
[
i
]
+
seed_index
:
seed_index
;
const
auto
picked_begin
=
subgraph_indptr_data_ptr
[
indptr_offset
];
const
auto
picked_end
=
subgraph_indptr_data_ptr
[
indptr_offset
+
1
];
for
(
auto
j
=
picked_begin
;
j
<
picked_end
;
++
j
)
{
subgraph_indices_data_ptr
[
j
]
=
indices_data_ptr
[
picked_eids_data_ptr
[
j
]];
if
(
hetero_with_seed_offsets
&&
node_type_offset_
.
has_value
())
{
// Substract the node type offset from
// subgraph indices. Assuming
// node_type_offset has the same dtype as
// indices.
auto
node_type_offset_data
=
node_type_offset_
.
value
()
.
data_ptr
<
index_t
>
();
subgraph_indices_data_ptr
[
j
]
-=
node_type_offset_data
[
etype_id_to_src_ntype_id
[
i
]];
}
}
}
}));
if
(
!
hetero_with_seed_offsets
&&
type_per_edge_
.
has_value
())
{
// When hetero graph is sampled as a homo graph, we
// still generate type_per_edge tensor for this
// situation.
AT_DISPATCH_INTEGRAL_TYPES
(
subgraph_type_per_edge
.
value
().
scalar_type
(),
"IndexSelectTypePerEdge"
,
([
&
]
{
auto
subgraph_type_per_edge_data_ptr
=
subgraph_type_per_edge
.
value
()
.
data_ptr
<
scalar_t
>
();
auto
type_per_edge_data_ptr
=
type_per_edge_
.
value
().
data_ptr
<
scalar_t
>
();
const
auto
picked_offset
=
subgraph_indptr_data_ptr
[
seed_index
];
for
(
auto
j
=
picked_offset
;
j
<
picked_offset
+
picked_number
;
++
j
)
subgraph_type_per_edge_data_ptr
[
j
]
=
type_per_edge_data_ptr
[
picked_eids_data_ptr
[
j
]];
}));
}
}
}
});
}));
}));
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
if
(
return_eids
)
subgraph_reverse_edge_ids
=
std
::
move
(
picked_eids
);
if
(
subgraph_indptr_substract
.
has_value
())
{
subgraph_indptr
-=
subgraph_indptr_substract
.
value
();
}
return
c10
::
make_intrusive
<
FusedSampledSubgraph
>
(
subgraph_indptr
,
subgraph_indices
,
seeds
,
torch
::
nullopt
,
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
,
edge_offsets
);
}
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
TemporalSampleNeighborsImpl
(
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
{
PickFn
pick_fn
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
...
@@ -663,6 +972,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -663,6 +972,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
}
}
bool
with_seed_offsets
=
seed_offsets
.
has_value
();
if
(
layer
)
{
if
(
layer
)
{
if
(
random_seed
.
has_value
()
&&
random_seed
->
numel
()
>=
2
)
{
if
(
random_seed
.
has_value
()
&&
random_seed
->
numel
()
>=
2
)
{
SamplerArgs
<
SamplerType
::
LABOR_DEPENDENT
>
args
{
SamplerArgs
<
SamplerType
::
LABOR_DEPENDENT
>
args
{
...
@@ -670,11 +981,13 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -670,11 +981,13 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
{
random_seed
.
value
(),
static_cast
<
float
>
(
seed2_contribution
)},
NumNodes
()};
NumNodes
()};
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
(
seeds
.
value
(),
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
with_seed_offsets
),
GetPickFn
(
GetPickFn
(
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
args
));
probs_or_mask
,
with_seed_offsets
,
args
));
}
else
{
}
else
{
auto
args
=
[
&
]
{
auto
args
=
[
&
]
{
if
(
random_seed
.
has_value
()
&&
random_seed
->
numel
()
==
1
)
{
if
(
random_seed
.
has_value
()
&&
random_seed
->
numel
()
==
1
)
{
...
@@ -689,20 +1002,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
...
@@ -689,20 +1002,23 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
}
}
}();
}();
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
(
seeds
.
value
(),
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
with_seed_offsets
),
GetPickFn
(
GetPickFn
(
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
args
));
probs_or_mask
,
with_seed_offsets
,
args
));
}
}
}
else
{
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
SampleNeighborsImpl
(
return
SampleNeighborsImpl
(
seeds
.
value
(),
return_eids
,
seeds
.
value
(),
seed_offsets
,
fanouts
,
return_eids
,
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
),
GetNumPickFn
(
fanouts
,
replace
,
type_per_edge_
,
probs_or_mask
,
with_seed_offsets
),
GetPickFn
(
GetPickFn
(
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
fanouts
,
replace
,
indptr_
.
options
(),
type_per_edge_
,
probs_or_mask
,
args
));
with_seed_offsets
,
args
));
}
}
}
}
...
@@ -734,7 +1050,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -734,7 +1050,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
const
int64_t
random_seed
=
RandomEngine
::
ThreadLocal
()
->
RandInt
(
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
static_cast
<
int64_t
>
(
0
),
std
::
numeric_limits
<
int64_t
>::
max
());
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
SamplerArgs
<
SamplerType
::
LABOR
>
args
{
indices_
,
random_seed
,
NumNodes
()};
return
SampleNeighborsImpl
(
return
Temporal
SampleNeighborsImpl
(
input_nodes
,
return_eids
,
input_nodes
,
return_eids
,
GetTemporalNumPickFn
(
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
...
@@ -745,7 +1061,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
...
@@ -745,7 +1061,7 @@ FusedCSCSamplingGraph::TemporalSampleNeighbors(
edge_timestamp
,
args
));
edge_timestamp
,
args
));
}
else
{
}
else
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
;
return
SampleNeighborsImpl
(
return
Temporal
SampleNeighborsImpl
(
input_nodes
,
return_eids
,
input_nodes
,
return_eids
,
GetTemporalNumPickFn
(
GetTemporalNumPickFn
(
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
input_nodes_timestamp
,
this
->
indices_
,
fanouts
,
replace
,
...
@@ -806,12 +1122,13 @@ void FusedCSCSamplingGraph::HoldSharedMemoryObject(
...
@@ -806,12 +1122,13 @@ void FusedCSCSamplingGraph::HoldSharedMemoryObject(
tensor_data_shm_
=
std
::
move
(
tensor_data_shm
);
tensor_data_shm_
=
std
::
move
(
tensor_data_shm
);
}
}
int64_t
NumPick
(
template
<
typename
PickedNumType
>
void
NumPick
(
int64_t
fanout
,
bool
replace
,
int64_t
fanout
,
bool
replace
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
int64_t
num_neighbors
)
{
int64_t
num_neighbors
,
PickedNumType
*
picked_num_ptr
)
{
int64_t
num_valid_neighbors
=
num_neighbors
;
int64_t
num_valid_neighbors
=
num_neighbors
;
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
()
&&
num_neighbors
>
0
)
{
// Subtract the count of zeros in probs_or_mask.
// Subtract the count of zeros in probs_or_mask.
AT_DISPATCH_ALL_TYPES
(
AT_DISPATCH_ALL_TYPES
(
probs_or_mask
.
value
().
scalar_type
(),
"CountZero"
,
([
&
]
{
probs_or_mask
.
value
().
scalar_type
(),
"CountZero"
,
([
&
]
{
...
@@ -821,8 +1138,11 @@ int64_t NumPick(
...
@@ -821,8 +1138,11 @@ int64_t NumPick(
0
);
0
);
}));
}));
}
}
if
(
num_valid_neighbors
==
0
||
fanout
==
-
1
)
return
num_valid_neighbors
;
if
(
num_valid_neighbors
==
0
||
fanout
==
-
1
)
{
return
replace
?
fanout
:
std
::
min
(
fanout
,
num_valid_neighbors
);
*
picked_num_ptr
=
num_valid_neighbors
;
}
else
{
*
picked_num_ptr
=
replace
?
fanout
:
std
::
min
(
fanout
,
num_valid_neighbors
);
}
}
}
torch
::
Tensor
TemporalMask
(
torch
::
Tensor
TemporalMask
(
...
@@ -926,14 +1246,16 @@ int64_t TemporalNumPick(
...
@@ -926,14 +1246,16 @@ int64_t TemporalNumPick(
return
replace
?
fanout
:
std
::
min
(
fanout
,
num_valid_neighbors
);
return
replace
?
fanout
:
std
::
min
(
fanout
,
num_valid_neighbors
);
}
}
int64_t
NumPickByEtype
(
template
<
typename
PickedNumType
>
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
void
NumPickByEtype
(
bool
with_seed_offsets
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
int64_t
offset
,
int64_t
num_neighbors
)
{
int64_t
num_neighbors
,
PickedNumType
*
num_picked_ptr
,
int64_t
seed_index
,
const
std
::
vector
<
int64_t
>&
etype_id_to_num_picked_offset
)
{
int64_t
etype_begin
=
offset
;
int64_t
etype_begin
=
offset
;
const
int64_t
end
=
offset
+
num_neighbors
;
const
int64_t
end
=
offset
+
num_neighbors
;
int64_t
total_count
=
0
;
PickedNumType
total_count
=
0
;
AT_DISPATCH_INTEGRAL_TYPES
(
AT_DISPATCH_INTEGRAL_TYPES
(
type_per_edge
.
scalar_type
(),
"NumPickFnByEtype"
,
([
&
]
{
type_per_edge
.
scalar_type
(),
"NumPickFnByEtype"
,
([
&
]
{
const
scalar_t
*
type_per_edge_data
=
type_per_edge
.
data_ptr
<
scalar_t
>
();
const
scalar_t
*
type_per_edge_data
=
type_per_edge
.
data_ptr
<
scalar_t
>
();
...
@@ -947,13 +1269,32 @@ int64_t NumPickByEtype(
...
@@ -947,13 +1269,32 @@ int64_t NumPickByEtype(
etype
);
etype
);
int64_t
etype_end
=
etype_end_it
-
type_per_edge_data
;
int64_t
etype_end
=
etype_end_it
-
type_per_edge_data
;
// Do sampling for one etype.
// Do sampling for one etype.
total_count
+=
NumPick
(
if
(
with_seed_offsets
)
{
fanouts
[
etype
],
replace
,
probs_or_mask
,
etype_begin
,
// The pick numbers aren't stored continuously, but separately for
etype_end
-
etype_begin
);
// each different etype.
const
auto
offset
=
etype_id_to_num_picked_offset
[
etype
]
+
seed_index
;
NumPick
(
fanouts
[
etype
],
replace
,
probs_or_mask
,
etype_begin
,
etype_end
-
etype_begin
,
num_picked_ptr
+
offset
);
// Use the skipped position of each edge type in the
// num_picked_tensor to sum up the total pick number for each edge
// type.
num_picked_ptr
[
etype_id_to_num_picked_offset
[
etype
]
-
1
]
+=
num_picked_ptr
[
offset
];
}
else
{
PickedNumType
picked_count
=
0
;
NumPick
(
fanouts
[
etype
],
replace
,
probs_or_mask
,
etype_begin
,
etype_end
-
etype_begin
,
&
picked_count
);
total_count
+=
picked_count
;
}
etype_begin
=
etype_end
;
etype_begin
=
etype_end
;
}
}
}));
}));
return
total_count
;
if
(
!
with_seed_offsets
)
{
num_picked_ptr
[
seed_index
]
=
total_count
;
}
}
}
int64_t
TemporalNumPickByEtype
(
int64_t
TemporalNumPickByEtype
(
...
@@ -1265,6 +1606,7 @@ int64_t Pick(
...
@@ -1265,6 +1606,7 @@ int64_t Pick(
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
,
PickedType
*
picked_data_ptr
)
{
SamplerArgs
<
SamplerType
::
NEIGHBOR
>
args
,
PickedType
*
picked_data_ptr
)
{
if
(
fanout
==
0
||
num_neighbors
==
0
)
return
0
;
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
return
NonUniformPick
(
return
NonUniformPick
(
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
probs_or_mask
.
value
(),
offset
,
num_neighbors
,
fanout
,
replace
,
options
,
probs_or_mask
.
value
(),
...
@@ -1326,14 +1668,16 @@ int64_t TemporalPick(
...
@@ -1326,14 +1668,16 @@ int64_t TemporalPick(
template
<
SamplerType
S
,
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
int64_t
PickByEtype
(
int64_t
PickByEtype
(
int64_t
offset
,
int64_t
num_neighbors
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
with_seed_offsets
,
int64_t
offset
,
int64_t
num_neighbors
,
bool
replace
,
const
torch
::
TensorOptions
&
options
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
Tensor
&
type_per_edge
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
PickedType
*
picked_data_ptr
)
{
PickedType
*
picked_data_ptr
,
int64_t
seed_index
,
PickedType
*
subgraph_indptr_ptr
,
const
std
::
vector
<
int64_t
>&
etype_id_to_num_picked_offset
)
{
int64_t
etype_begin
=
offset
;
int64_t
etype_begin
=
offset
;
int64_t
etype_end
=
offset
;
int64_t
etype_end
=
offset
;
int64_t
pick
_offse
t
=
0
;
int64_t
pick
ed_total_coun
t
=
0
;
AT_DISPATCH_INTEGRAL_TYPES
(
AT_DISPATCH_INTEGRAL_TYPES
(
type_per_edge
.
scalar_type
(),
"PickByEtype"
,
([
&
]
{
type_per_edge
.
scalar_type
(),
"PickByEtype"
,
([
&
]
{
const
scalar_t
*
type_per_edge_data
=
type_per_edge
.
data_ptr
<
scalar_t
>
();
const
scalar_t
*
type_per_edge_data
=
type_per_edge
.
data_ptr
<
scalar_t
>
();
...
@@ -1348,17 +1692,36 @@ int64_t PickByEtype(
...
@@ -1348,17 +1692,36 @@ int64_t PickByEtype(
type_per_edge_data
+
etype_begin
,
type_per_edge_data
+
end
,
type_per_edge_data
+
etype_begin
,
type_per_edge_data
+
end
,
etype
);
etype
);
etype_end
=
etype_end_it
-
type_per_edge_data
;
etype_end
=
etype_end_it
-
type_per_edge_data
;
// Do sampling for one etype.
// Do sampling for one etype. The picked nodes aren't stored
// continuously, but separately for each different etype.
if
(
fanout
!=
0
)
{
if
(
fanout
!=
0
)
{
int64_t
picked_count
=
Pick
(
auto
picked_count
=
0
;
etype_begin
,
etype_end
-
etype_begin
,
fanout
,
replace
,
options
,
if
(
with_seed_offsets
)
{
probs_or_mask
,
args
,
picked_data_ptr
+
pick_offset
);
const
auto
indptr_offset
=
pick_offset
+=
picked_count
;
etype_id_to_num_picked_offset
[
etype
]
+
seed_index
;
picked_count
=
Pick
(
etype_begin
,
etype_end
-
etype_begin
,
fanout
,
replace
,
options
,
probs_or_mask
,
args
,
picked_data_ptr
+
subgraph_indptr_ptr
[
indptr_offset
]);
TORCH_CHECK
(
subgraph_indptr_ptr
[
indptr_offset
+
1
]
-
subgraph_indptr_ptr
[
indptr_offset
]
==
picked_count
,
"Actual picked count doesn't match the calculated "
"pick number."
);
}
else
{
picked_count
=
Pick
(
etype_begin
,
etype_end
-
etype_begin
,
fanout
,
replace
,
options
,
probs_or_mask
,
args
,
picked_data_ptr
+
subgraph_indptr_ptr
[
seed_index
]
+
picked_total_count
);
}
picked_total_count
+=
picked_count
;
}
}
etype_begin
=
etype_end
;
etype_begin
=
etype_end
;
}
}
}));
}));
return
pick
_offse
t
;
return
pick
ed_total_coun
t
;
}
}
template
<
SamplerType
S
,
typename
PickedType
>
template
<
SamplerType
S
,
typename
PickedType
>
...
@@ -1409,7 +1772,7 @@ std::enable_if_t<is_labor(S), int64_t> Pick(
...
@@ -1409,7 +1772,7 @@ std::enable_if_t<is_labor(S), int64_t> Pick(
const
torch
::
TensorOptions
&
options
,
const
torch
::
TensorOptions
&
options
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
const
torch
::
optional
<
torch
::
Tensor
>&
probs_or_mask
,
SamplerArgs
<
S
>
args
,
PickedType
*
picked_data_ptr
)
{
PickedType
*
picked_data_ptr
)
{
if
(
fanout
==
0
)
return
0
;
if
(
fanout
==
0
||
num_neighbors
==
0
)
return
0
;
if
(
probs_or_mask
.
has_value
())
{
if
(
probs_or_mask
.
has_value
())
{
if
(
fanout
<
0
)
{
if
(
fanout
<
0
)
{
return
NonUniformPick
(
return
NonUniformPick
(
...
...
tests/python/pytorch/graphbolt/impl/test_fused_csc_sampling_graph.py
View file @
658b2086
...
@@ -2219,10 +2219,13 @@ def test_sample_neighbors_hetero_pick_number(
...
@@ -2219,10 +2219,13 @@ def test_sample_neighbors_hetero_pick_number(
type_per_edge
=
type_per_edge
,
type_per_edge
=
type_per_edge
,
node_type_to_id
=
ntypes
,
node_type_to_id
=
ntypes
,
edge_type_to_id
=
etypes
,
edge_type_to_id
=
etypes
,
)
)
.
to
(
F
.
ctx
())
# Generate subgraph via sample neighbors.
# Generate subgraph via sample neighbors.
nodes
=
torch
.
LongTensor
([
0
,
1
])
nodes
=
{
"N0"
:
torch
.
LongTensor
([
0
]).
to
(
F
.
ctx
()),
"N1"
:
torch
.
LongTensor
([
1
]).
to
(
F
.
ctx
()),
}
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
sampler
=
graph
.
sample_layer_neighbors
if
labor
else
graph
.
sample_neighbors
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment