Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
97ed294d
Unverified
Commit
97ed294d
authored
Dec 20, 2023
by
Ramon Zhou
Committed by
GitHub
Dec 20, 2023
Browse files
[GraphBolt] Delete DGLMiniBatch (#6760)
parent
b483c26f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
227 deletions
+4
-227
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+2
-225
python/dgl/graphbolt/minibatch_transformer.py
python/dgl/graphbolt/minibatch_transformer.py
+2
-2
No files found.
python/dgl/graphbolt/minibatch.py
View file @
97ed294d
"""Unified data structure for input and ouput of all the stages in loading process."""
"""Unified data structure for input and ouput of all the stages in loading process."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
torch
import
torch
import
dgl
import
dgl
from
dgl.heterograph
import
DGLBlock
from
dgl.utils
import
recursive_apply
from
dgl.utils
import
recursive_apply
from
.base
import
CSCFormatBase
,
etype_str_to_tuple
from
.base
import
CSCFormatBase
,
etype_str_to_tuple
from
.internal
import
get_attributes
from
.internal
import
get_attributes
from
.sampled_subgraph
import
SampledSubgraph
from
.sampled_subgraph
import
SampledSubgraph
__all__
=
[
"DGLMiniBatch"
,
"MiniBatch"
]
__all__
=
[
"MiniBatch"
]
@
dataclass
class
MiniBatchBase
(
object
):
"""Base class for `MiniBatch` and `DGLMiniBatch`."""
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the MiniBatch.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
raise
NotImplementedError
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
raise
NotImplementedError
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
raise
NotImplementedError
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
raise
NotImplementedError
def
edge_ids
(
self
,
layer_id
:
int
)
->
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]:
"""Get the edge ids of a layer."""
raise
NotImplementedError
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy MiniBatch to the specified device."""
raise
NotImplementedError
@
dataclass
class
DGLMiniBatch
(
MiniBatchBase
):
r
"""A data class designed for the DGL library, encompassing all the
necessary fields for computation using the DGL library."""
blocks
:
List
[
DGLBlock
]
=
None
"""A list of 'DGLBlock's, each one corresponding to one layer, representing
a bipartite graph used for message passing.
"""
input_nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""A representation of input nodes in the outermost layer. Conatins all
nodes in the 'blocks'.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
output_nodes
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""Representation of output nodes, usually also the seed nodes, used for
sampling in the graph.
- If `output_nodes` is a tensor: It indicates the graph is homogeneous.
- If `output_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
"""
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
]
=
None
"""A representation of node features.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(node_type, feature_name)'.
"""
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
]
=
None
"""Edge features associated with the 'blocks'.
- If keys are single strings: It means the graph is homogeneous, and the
keys are feature names.
- If keys are tuples: It means the graph is heterogeneous, and the keys
are tuples of '(edge_type, feature_name)'. Note, edge type is a triplet
of format (str, str, str).
"""
labels
:
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]
=
None
"""Labels associated with seed nodes / node pairs in the graph.
- If `labels` is a tensor: It indicates the graph is homogeneous. The value
are corresponding labels to given 'output_nodes' or 'node_pairs'.
- If `labels` is a dictionary: The keys are node or edge type and the value
should be corresponding labels to given 'output_nodes' or 'node_pairs'.
"""
positive_node_pairs
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
"""Representation of positive graphs used for evaluating or computing loss
in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
negative_node_pairs
:
Union
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
Dict
[
str
,
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
]
=
None
"""Representation of negative graphs used for evaluating or computing loss in
link prediction tasks.
- If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `negative_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
def
__repr__
(
self
)
->
str
:
return
_dgl_minibatch_str
(
self
)
def
node_ids
(
self
)
->
Union
[
torch
.
Tensor
,
Dict
[
str
,
torch
.
Tensor
]]:
"""A representation of input nodes in the outermost layer. Contains all
nodes in the `blocks`.
- If `input_nodes` is a tensor: It indicates the graph is homogeneous.
- If `input_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node id.
"""
return
self
.
input_nodes
def
num_layers
(
self
)
->
int
:
"""Return the number of layers."""
if
self
.
blocks
is
None
:
return
0
return
len
(
self
.
blocks
)
def
edge_ids
(
self
,
layer_id
:
int
)
->
Optional
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
torch
.
Tensor
]]:
"""Get edge ids of a layer."""
if
dgl
.
EID
not
in
self
.
blocks
[
layer_id
].
edata
:
return
None
return
self
.
blocks
[
layer_id
].
edata
[
dgl
.
EID
]
def
set_node_features
(
self
,
node_features
:
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]
],
)
->
None
:
"""Set node features."""
self
.
node_features
=
node_features
def
set_edge_features
(
self
,
edge_features
:
List
[
Union
[
Dict
[
str
,
torch
.
Tensor
],
Dict
[
Tuple
[
str
,
str
],
torch
.
Tensor
]]
],
)
->
None
:
"""Set edge features."""
self
.
edge_features
=
edge_features
def
to
(
self
,
device
:
torch
.
device
)
->
None
:
# pylint: disable=invalid-name
"""Copy `DGLMiniBatch` to the specified device using reflection."""
def
_to
(
x
,
device
):
return
x
.
to
(
device
)
if
hasattr
(
x
,
"to"
)
else
x
for
attr
in
dir
(
self
):
# Only copy member variables.
if
not
callable
(
getattr
(
self
,
attr
))
and
not
attr
.
startswith
(
"__"
):
setattr
(
self
,
attr
,
recursive_apply
(
getattr
(
self
,
attr
),
lambda
x
:
_to
(
x
,
device
)
),
)
return
self
@
dataclass
@
dataclass
...
@@ -654,34 +462,3 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
...
@@ -654,34 +462,3 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
final_str
+
f
"
{
name
}
=
{
_add_indent
(
val
,
len
(
name
)
+
1
)
}
,
\n
"
+
" "
*
10
final_str
+
f
"
{
name
}
=
{
_add_indent
(
val
,
len
(
name
)
+
1
)
}
,
\n
"
+
" "
*
10
)
)
return
"MiniBatch("
+
final_str
[:
-
3
]
+
")"
return
"MiniBatch("
+
final_str
[:
-
3
]
+
")"
def
_dgl_minibatch_str
(
dglminibatch
:
DGLMiniBatch
)
->
str
:
final_str
=
""
# Get all attributes in the class except methods.
attributes
=
get_attributes
(
dglminibatch
)
attributes
.
reverse
()
# Insert key with its value into the string.
for
name
in
attributes
:
val
=
getattr
(
dglminibatch
,
name
)
def
_add_indent
(
_str
,
indent
):
lines
=
_str
.
split
(
"
\n
"
)
lines
=
[
lines
[
0
]]
+
[
" "
*
indent
+
line
for
line
in
lines
[
1
:]]
return
"
\n
"
.
join
(
lines
)
# Let the variables in the list occupy one line each, and adjust the
# indentation on top of the original if the original data output has
# line feeds.
if
isinstance
(
val
,
list
):
val
=
[
str
(
val_str
)
for
val_str
in
val
]
val
=
"["
+
",
\n
"
.
join
(
val
)
+
"]"
elif
isinstance
(
val
,
tuple
):
val
=
[
str
(
val_str
)
for
val_str
in
val
]
val
=
"("
+
",
\n
"
.
join
(
val
)
+
")"
else
:
val
=
str
(
val
)
final_str
=
(
final_str
+
f
"
{
name
}
=
{
_add_indent
(
val
,
len
(
name
)
+
15
)
}
,
\n
"
+
" "
*
13
)
return
"DGLMiniBatch("
+
final_str
[:
-
3
]
+
")"
python/dgl/graphbolt/minibatch_transformer.py
View file @
97ed294d
...
@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
...
@@ -4,7 +4,7 @@ from torch.utils.data import functional_datapipe
from
torchdata.datapipes.iter
import
Mapper
from
torchdata.datapipes.iter
import
Mapper
from
.minibatch
import
DGLMiniBatch
,
MiniBatch
from
.minibatch
import
MiniBatch
__all__
=
[
__all__
=
[
"MiniBatchTransformer"
,
"MiniBatchTransformer"
,
...
@@ -37,6 +37,6 @@ class MiniBatchTransformer(Mapper):
...
@@ -37,6 +37,6 @@ class MiniBatchTransformer(Mapper):
def
_transformer
(
self
,
minibatch
):
def
_transformer
(
self
,
minibatch
):
minibatch
=
self
.
transformer
(
minibatch
)
minibatch
=
self
.
transformer
(
minibatch
)
assert
isinstance
(
assert
isinstance
(
minibatch
,
(
MiniBatch
,
DGLMiniBatch
)
minibatch
,
(
MiniBatch
,)
),
"The transformer output should be an instance of MiniBatch"
),
"The transformer output should be an instance of MiniBatch"
return
minibatch
return
minibatch
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment