Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e6a15c1a
Unverified
Commit
e6a15c1a
authored
Oct 24, 2023
by
Mingbang Wang
Committed by
GitHub
Oct 24, 2023
Browse files
[GraphBolt] Enable `node_classification.py` to run on GPU environment (#6490)
parent
2a92dfca
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
8 deletions
+44
-8
examples/sampling/graphbolt/node_classification.py
examples/sampling/graphbolt/node_classification.py
+44
-8
No files found.
examples/sampling/graphbolt/node_classification.py
View file @
e6a15c1a
...
...
@@ -47,7 +47,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torchmetrics.functional
as
MF
import
tqdm
from
tqdm
import
tqdm
def
create_dataloader
(
...
...
@@ -145,6 +145,16 @@ def create_dataloader(
############################################################################
# [Step-5]:
# self.copy_to()
# [Input]:
# 'device': The device to copy the data to.
# [Output]:
# A CopyTo object to copy the data to the specified device.
############################################################################
datapipe
=
datapipe
.
copy_to
(
device
=
args
.
device
)
############################################################################
# [Step-6]:
# gb.MultiProcessDataLoader()
# [Input]:
# 'datapipe': The datapipe object to be used for data loading.
...
...
@@ -191,10 +201,15 @@ class SAGE(nn.Module):
hidden_x
=
self
.
dropout
(
hidden_x
)
return
hidden_x
def
inference
(
self
,
graph
,
features
,
dataloader
):
def
inference
(
self
,
graph
,
features
,
dataloader
,
device
):
"""Conduct layer-wise inference to get all the node embeddings."""
feature
=
features
.
read
(
"node"
,
None
,
"feat"
)
buffer_device
=
torch
.
device
(
"cpu"
)
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory
=
buffer_device
!=
device
for
layer_idx
,
layer
in
enumerate
(
self
.
layers
):
is_last_layer
=
layer_idx
==
len
(
self
.
layers
)
-
1
...
...
@@ -202,16 +217,21 @@ class SAGE(nn.Module):
graph
.
total_num_nodes
,
self
.
out_size
if
is_last_layer
else
self
.
hidden_size
,
dtype
=
torch
.
float64
,
device
=
buffer_device
,
pin_memory
=
pin_memory
,
)
feature
=
feature
.
to
(
device
)
for
step
,
data
in
tqdm
.
tqdm
(
enumerate
(
dataloader
)):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
x
=
feature
[
data
.
input_nodes
]
hidden_x
=
layer
(
data
.
blocks
[
0
],
x
)
# len(blocks) = 1
if
not
is_last_layer
:
hidden_x
=
F
.
relu
(
hidden_x
)
hidden_x
=
self
.
dropout
(
hidden_x
)
# By design, our output nodes are contiguous.
y
[
data
.
output_nodes
[
0
]
:
data
.
output_nodes
[
-
1
]
+
1
]
=
hidden_x
y
[
data
.
output_nodes
[
0
]
:
data
.
output_nodes
[
-
1
]
+
1
]
=
hidden_x
.
to
(
buffer_device
)
feature
=
y
return
y
...
...
@@ -225,7 +245,7 @@ def layerwise_infer(
dataloader
=
create_dataloader
(
args
,
graph
,
features
,
all_nodes_set
,
job
=
"infer"
)
pred
=
model
.
inference
(
graph
,
features
,
dataloader
)
pred
=
model
.
inference
(
graph
,
features
,
dataloader
,
args
.
device
)
pred
=
pred
[
test_set
.
_items
[
0
]]
label
=
test_set
.
_items
[
1
].
to
(
pred
.
device
)
...
...
@@ -246,7 +266,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
args
,
graph
,
features
,
itemset
,
job
=
"evaluate"
)
for
step
,
data
in
tqdm
.
tqdm
(
enumerate
(
dataloader
)):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
x
=
data
.
node_features
[
"feat"
]
y
.
append
(
data
.
labels
)
y_hats
.
append
(
model
(
data
.
blocks
,
x
))
...
...
@@ -265,10 +285,10 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
args
,
graph
,
features
,
train_set
,
job
=
"train"
)
for
epoch
in
tqdm
.
t
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
epochs
):
model
.
train
()
total_loss
=
0
for
step
,
data
in
tqdm
.
tqdm
(
enumerate
(
dataloader
)):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
# The input features from the source nodes in the first layer's
# computation graph.
x
=
data
.
node_features
[
"feat"
]
...
...
@@ -326,14 +346,28 @@ def parse_args():
help
=
"Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 15,10,5"
,
)
parser
.
add_argument
(
"--device"
,
default
=
"cpu"
,
choices
=
[
"cpu"
,
"cuda"
],
help
=
"Train device: 'cpu' for CPU, 'cuda' for GPU."
,
)
return
parser
.
parse_args
()
def
main
(
args
):
if
not
torch
.
cuda
.
is_available
():
args
.
device
=
"cpu"
print
(
f
"Training in
{
args
.
device
}
mode."
)
args
.
device
=
torch
.
device
(
args
.
device
)
# Load and preprocess dataset.
print
(
"Loading data..."
)
dataset
=
gb
.
BuiltinDataset
(
"ogbn-products"
).
load
()
graph
=
dataset
.
graph
# Currently the neighbor-sampling process can only be done on the CPU,
# therefore there is no need to copy the graph to the GPU.
features
=
dataset
.
feature
train_set
=
dataset
.
tasks
[
0
].
train_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
...
...
@@ -348,6 +382,8 @@ def main(args):
out_size
=
num_classes
model
=
SAGE
(
in_size
,
hidden_size
,
out_size
)
assert
len
(
args
.
fanout
)
==
len
(
model
.
layers
)
model
=
model
.
to
(
args
.
device
)
# Model training.
print
(
"Training..."
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment