OpenDAS / dgl, commit 97ed294d (unverified)
Authored Dec 20, 2023 by Ramon Zhou; committed by GitHub on Dec 20, 2023

[GraphBolt] Delete DGLMiniBatch (#6760)
Parent: b483c26f

Showing 2 changed files with 4 additions and 227 deletions:

  python/dgl/graphbolt/minibatch.py              +2  -225
  python/dgl/graphbolt/minibatch_transformer.py  +2  -2
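For downstream code, the practical effect of this commit is that DGLMiniBatch is no longer importable from dgl.graphbolt.minibatch, so imports and isinstance checks have to fall back to MiniBatch alone. A minimal, hypothetical migration sketch (the module path is taken from the file paths in this diff; the check mirrors the minibatch_transformer.py change further down):

# Before this commit (removed):
#   from dgl.graphbolt.minibatch import DGLMiniBatch, MiniBatch
#   assert isinstance(minibatch, (MiniBatch, DGLMiniBatch))

# After this commit:
from dgl.graphbolt.minibatch import MiniBatch


def check_minibatch(minibatch):
    # MiniBatch is now the single unified data structure.
    assert isinstance(
        minibatch, (MiniBatch,)
    ), "The transformer output should be an instance of MiniBatch"
    return minibatch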
python/dgl/graphbolt/minibatch.py  (view file @ 97ed294d)
"""Unified data structure for input and ouput of all the stages in loading process."""
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
torch
import
dgl
from
dgl.heterograph
import
DGLBlock
from
dgl.utils
import
recursive_apply
from
.base
import
CSCFormatBase
,
etype_str_to_tuple
from
.internal
import
get_attributes
from
.sampled_subgraph
import
SampledSubgraph
__all__
=
[
"DGLMiniBatch"
,
"MiniBatch"
]
-@dataclass
-class MiniBatchBase(object):
-    """Base class for `MiniBatch` and `DGLMiniBatch`."""
-
-    def node_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        """A representation of input nodes in the outermost layer. Contains all
-        nodes in the MiniBatch.
-        - If `input_nodes` is a tensor: It indicates the graph is homogeneous.
-        - If `input_nodes` is a dictionary: The keys should be node type and the
-          value should be corresponding heterogeneous node id.
-        """
-        raise NotImplementedError
-
-    def num_layers(self) -> int:
-        """Return the number of layers."""
-        raise NotImplementedError
-
-    def set_node_features(
-        self,
-        node_features: Union[
-            Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
-        ],
-    ) -> None:
-        """Set node features."""
-        raise NotImplementedError
-
-    def set_edge_features(
-        self,
-        edge_features: List[
-            Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
-        ],
-    ) -> None:
-        """Set edge features."""
-        raise NotImplementedError
-
-    def edge_ids(
-        self, layer_id: int
-    ) -> Union[Dict[str, torch.Tensor], torch.Tensor]:
-        """Get the edge ids of a layer."""
-        raise NotImplementedError
-
-    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
-        """Copy MiniBatch to the specified device."""
-        raise NotImplementedError
-@dataclass
-class DGLMiniBatch(MiniBatchBase):
-    r"""A data class designed for the DGL library, encompassing all the
-    necessary fields for computation using the DGL library."""
-
-    blocks: List[DGLBlock] = None
-    """A list of 'DGLBlock's, each one corresponding to one layer, representing
-    a bipartite graph used for message passing.
-    """
-
-    input_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
-    """A representation of input nodes in the outermost layer. Conatins all
-    nodes in the 'blocks'.
-    - If `input_nodes` is a tensor: It indicates the graph is homogeneous.
-    - If `input_nodes` is a dictionary: The keys should be node type and the
-      value should be corresponding heterogeneous node id.
-    """
-
-    output_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
-    """Representation of output nodes, usually also the seed nodes, used for
-    sampling in the graph.
-    - If `output_nodes` is a tensor: It indicates the graph is homogeneous.
-    - If `output_nodes` is a dictionary: The keys should be node type and the
-      value should be corresponding heterogeneous node ids.
-    """
-
-    node_features: Union[
-        Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
-    ] = None
-    """A representation of node features.
-    - If keys are single strings: It means the graph is homogeneous, and the
-      keys are feature names.
-    - If keys are tuples: It means the graph is heterogeneous, and the keys
-      are tuples of '(node_type, feature_name)'.
-    """
-
-    edge_features: List[
-        Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
-    ] = None
-    """Edge features associated with the 'blocks'.
-    - If keys are single strings: It means the graph is homogeneous, and the
-      keys are feature names.
-    - If keys are tuples: It means the graph is heterogeneous, and the keys
-      are tuples of '(edge_type, feature_name)'. Note, edge type is a triplet
-      of format (str, str, str).
-    """
-
-    labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
-    """Labels associated with seed nodes / node pairs in the graph.
-    - If `labels` is a tensor: It indicates the graph is homogeneous. The value
-      are corresponding labels to given 'output_nodes' or 'node_pairs'.
-    - If `labels` is a dictionary: The keys are node or edge type and the value
-      should be corresponding labels to given 'output_nodes' or 'node_pairs'.
-    """
-
-    positive_node_pairs: Union[
-        Tuple[torch.Tensor, torch.Tensor],
-        Dict[str, Tuple[torch.Tensor, torch.Tensor]],
-    ] = None
-    """Representation of positive graphs used for evaluating or computing loss
-    in link prediction tasks.
-    - If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
-      containing two tensors representing source-destination node pairs.
-    - If `positive_node_pairs` is a dictionary: The keys should be edge type,
-      and the value should be a tuple of tensors representing node pairs of the
-      given type.
-    """
-
-    negative_node_pairs: Union[
-        Tuple[torch.Tensor, torch.Tensor],
-        Dict[str, Tuple[torch.Tensor, torch.Tensor]],
-    ] = None
-    """Representation of negative graphs used for evaluating or computing loss in
-    link prediction tasks.
-    - If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
-      containing two tensors representing source-destination node pairs.
-    - If `negative_node_pairs` is a dictionary: The keys should be edge type,
-      and the value should be a tuple of tensors representing node pairs of the
-      given type.
-    """
-
-    def __repr__(self) -> str:
-        return _dgl_minibatch_str(self)
-
-    def node_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        """A representation of input nodes in the outermost layer. Contains all
-        nodes in the `blocks`.
-        - If `input_nodes` is a tensor: It indicates the graph is homogeneous.
-        - If `input_nodes` is a dictionary: The keys should be node type and the
-          value should be corresponding heterogeneous node id.
-        """
-        return self.input_nodes
-
-    def num_layers(self) -> int:
-        """Return the number of layers."""
-        if self.blocks is None:
-            return 0
-        return len(self.blocks)
-
-    def edge_ids(
-        self, layer_id: int
-    ) -> Optional[Union[Dict[str, torch.Tensor], torch.Tensor]]:
-        """Get edge ids of a layer."""
-        if dgl.EID not in self.blocks[layer_id].edata:
-            return None
-        return self.blocks[layer_id].edata[dgl.EID]
-
-    def set_node_features(
-        self,
-        node_features: Union[
-            Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]
-        ],
-    ) -> None:
-        """Set node features."""
-        self.node_features = node_features
-
-    def set_edge_features(
-        self,
-        edge_features: List[
-            Union[Dict[str, torch.Tensor], Dict[Tuple[str, str], torch.Tensor]]
-        ],
-    ) -> None:
-        """Set edge features."""
-        self.edge_features = edge_features
-
-    def to(self, device: torch.device) -> None:  # pylint: disable=invalid-name
-        """Copy `DGLMiniBatch` to the specified device using reflection."""
-
-        def _to(x, device):
-            return x.to(device) if hasattr(x, "to") else x
-
-        for attr in dir(self):
-            # Only copy member variables.
-            if not callable(getattr(self, attr)) and not attr.startswith("__"):
-                setattr(
-                    self,
-                    attr,
-                    recursive_apply(
-                        getattr(self, attr), lambda x: _to(x, device)
-                    ),
-                )
-
-        return self
+__all__ = ["MiniBatch"]

 @dataclass
...
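The removed DGLMiniBatch.to() above moves every tensor-like member to a device by reflecting over the instance's attributes and recursively mapping a small _to helper across nested containers. Below is a minimal, self-contained sketch of that pattern; recursive_apply here is a simplified stand-in for dgl.utils.recursive_apply so the example runs without DGL, and ToyBatch is a hypothetical container, not a GraphBolt class.

from dataclasses import dataclass
from typing import Dict, Optional

import torch


def recursive_apply(data, fn):
    # Stand-in for dgl.utils.recursive_apply: map `fn` over nested
    # dicts/lists/tuples, applying it to the leaves.
    if isinstance(data, dict):
        return {k: recursive_apply(v, fn) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(recursive_apply(v, fn) for v in data)
    return fn(data)


@dataclass
class ToyBatch:  # hypothetical container mirroring the pattern above
    node_features: Optional[Dict[str, torch.Tensor]] = None
    labels: Optional[torch.Tensor] = None

    def to(self, device: torch.device):  # pylint: disable=invalid-name
        def _to(x, device):
            # Move anything with a .to() method; leave other values untouched.
            return x.to(device) if hasattr(x, "to") else x

        for attr in dir(self):
            # Only copy member variables, skipping methods and dunders.
            if not callable(getattr(self, attr)) and not attr.startswith("__"):
                setattr(
                    self,
                    attr,
                    recursive_apply(getattr(self, attr), lambda x: _to(x, device)),
                )
        return self


batch = ToyBatch(node_features={"feat": torch.ones(4, 8)}, labels=torch.zeros(4))
batch = batch.to(torch.device("cpu"))  # use "cuda" when a GPU is available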
@@ -654,34 +462,3 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
             final_str + f"{name}={_add_indent(val, len(name) + 1)},\n" + " " * 10
         )
     return "MiniBatch(" + final_str[:-3] + ")"
-
-
-def _dgl_minibatch_str(dglminibatch: DGLMiniBatch) -> str:
-    final_str = ""
-    # Get all attributes in the class except methods.
-    attributes = get_attributes(dglminibatch)
-    attributes.reverse()
-    # Insert key with its value into the string.
-    for name in attributes:
-        val = getattr(dglminibatch, name)
-
-        def _add_indent(_str, indent):
-            lines = _str.split("\n")
-            lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
-            return "\n".join(lines)
-
-        # Let the variables in the list occupy one line each, and adjust the
-        # indentation on top of the original if the original data output has
-        # line feeds.
-        if isinstance(val, list):
-            val = [str(val_str) for val_str in val]
-            val = "[" + ",\n".join(val) + "]"
-        elif isinstance(val, tuple):
-            val = [str(val_str) for val_str in val]
-            val = "(" + ",\n".join(val) + ")"
-        else:
-            val = str(val)
-        final_str = (
-            final_str + f"{name}={_add_indent(val, len(name) + 15)},\n" + " " * 13
-        )
-    return "DGLMiniBatch(" + final_str[:-3] + ")"
python/dgl/graphbolt/minibatch_transformer.py  (view file @ 97ed294d)
...
@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
 from torchdata.datapipes.iter import Mapper

-from .minibatch import DGLMiniBatch, MiniBatch
+from .minibatch import MiniBatch

 __all__ = [
     "MiniBatchTransformer",
...
@@ -37,6 +37,6 @@ class MiniBatchTransformer(Mapper):
     def _transformer(self, minibatch):
         minibatch = self.transformer(minibatch)
         assert isinstance(
-            minibatch, (MiniBatch, DGLMiniBatch)
+            minibatch, (MiniBatch,)
         ), "The transformer output should be an instance of MiniBatch"
         return minibatch
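MiniBatchTransformer._transformer above is a Mapper callback: it applies the user-supplied transformer to each item and then asserts the result type. A generic sketch of that wrap-and-assert pattern follows, using torchdata.datapipes.iter as the diff does; it assumes a torchdata release that still ships datapipes, and ValidatedMapper plus the toy pipeline are illustrative, not GraphBolt API.

from torchdata.datapipes.iter import IterableWrapper, Mapper


class ValidatedMapper(Mapper):
    """Apply `transformer` to every item and type-check the result,
    mirroring MiniBatchTransformer._transformer in the diff above."""

    def __init__(self, datapipe, transformer, expected_type):
        # Mapper will invoke self._transformer for each item of `datapipe`.
        super().__init__(datapipe, self._transformer)
        self.transformer = transformer
        self.expected_type = expected_type

    def _transformer(self, item):
        item = self.transformer(item)
        assert isinstance(
            item, (self.expected_type,)
        ), f"The transformer output should be an instance of {self.expected_type.__name__}"
        return item


# Toy pipeline: double each value and insist the result is still an int.
dp = IterableWrapper([1, 2, 3])
dp = ValidatedMapper(dp, lambda x: x * 2, int)
print(list(dp))  # [2, 4, 6]

GraphBolt's own class keeps the same shape: the constructor passes a bound method to Mapper and stores the user callable on a separate attribute (self.transformer in the diff above).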