Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
cef5a14a
Unverified
Commit
cef5a14a
authored
Sep 22, 2023
by
peizhou001
Committed by
GitHub
Sep 22, 2023
Browse files
[Graphbolt] Add dgl minibatch (#6370)
parent
f8594230
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
84 additions
and
1 deletion
+84
-1
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+84
-1
No files found.
python/dgl/graphbolt/minibatch.py
View file @
cef5a14a
...
...
@@ -6,11 +6,94 @@ from typing import Dict, List, Tuple, Union
import
torch
import
dgl
from
dgl.heterograph
import
DGLBlock
from
.base
import
etype_str_to_tuple
from
.sampled_subgraph
import
SampledSubgraph
__all__
=
[
"MiniBatch"
]
__all__
=
[
"DGLMiniBatch"
,
"MiniBatch"
]
@
dataclass
class
DGLMiniBatch
:
r
"""A data class designed for the DGL library, encompassing all the
necessary fields for computation using the DGL library.."""
blocks
:
List
[
DGLBlock
]
=
None
"""A list of 'DGLBlock's, each one corresponding to one layer, representing
a bipartite graph used for message passing.
"""
input_nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""A representation of input nodes in the outermost layer. Conatins all
nodes in the 'blocks'.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
output_nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""Representation of output nodes, usually also the seed nodes, used for
sampling in the graph.
- If `output_nodes` is a tensor: It indicates the graph is homogeneous.
- If `output_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
"""
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
]
=
None
"""A representation of node features.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(node_type, feature_name)'.
"""
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
]
=
None
"""Edge features associated with the 'blocks'.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(edge_type, feature_name)'. Note, edge type is a triplet
of format (str, str, str).
"""
labels
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""Labels associated with seed nodes / node pairs in the graph.
- If `labels` is a tensor: It indicates the graph is homogeneous. The value
are corresponding labels to given 'output_nodes' or 'node_pairs'.
- If `labels` is a dictionary: The keys are node or edge type and the value
should be corresponding labels to given 'output_nodes' or 'node_pairs'.
"""
positive_node_pairs
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
"""Representation of positive graphs used for evaluating or computing loss
in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
negative_node_pairs
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
"""Representation of negative graphs used for evaluating or computing loss in
link prediction tasks.
- If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `negative_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
@
dataclass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment