OpenDAS / dgl · Commit 02025989 (unverified)
[Examples] refine dist train example (#5763)
Authored Jun 06, 2023 by Rhett Ying, committed via GitHub on Jun 06, 2023
Parent: 041f78ba

Showing 1 changed file with 145 additions and 91 deletions:
examples/distributed/graphsage/node_classification.py (+145, -91)
@@ -12,18 +12,26 @@ import torch.optim as optim
 import tqdm


-def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
+class DistSAGE(nn.Module):
     """
-    Copys features and labels of a set of nodes onto GPU.
+    SAGE model for distributed train and evaluation.
+
+    Parameters
+    ----------
+    in_feats : int
+        Feature dimension.
+    n_hidden : int
+        Hidden layer dimension.
+    n_classes : int
+        Number of classes.
+    n_layers : int
+        Number of layers.
+    activation : callable
+        Activation function.
+    dropout : float
+        Dropout value.
     """
-    batch_inputs = (
-        g.ndata["features"][input_nodes].to(device) if load_feat else None
-    )
-    batch_labels = g.ndata["labels"][seeds].to(device)
-    return batch_inputs, batch_labels
-
-
-class DistSAGE(nn.Module):
     def __init__(
         self, in_feats, n_hidden, n_classes, n_layers, activation, dropout
     ):
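For orientation, a minimal sketch (not part of this commit) of constructing the now-documented DistSAGE class. The hyperparameter values and the import path are assumptions for illustration only:

# Hypothetical usage sketch; hyperparameters and import path are made up.
import torch.nn.functional as F

from node_classification import DistSAGE  # assumed local import of this example

model = DistSAGE(
    in_feats=100,   # assumed input feature dimension
    n_hidden=16,
    n_classes=47,
    n_layers=2,
    activation=F.relu,
    dropout=0.5,
)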
@@ -40,6 +48,16 @@ class DistSAGE(nn.Module):
         self.activation = activation

     def forward(self, blocks, x):
+        """
+        Forward function.
+
+        Parameters
+        ----------
+        blocks : List[DGLBlock]
+            Sampled blocks.
+        x : DistTensor
+            Feature data.
+        """
         h = x
         for i, (layer, block) in enumerate(zip(self.layers, blocks)):
             h = layer(block, h)
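A small local, single-process sketch of the same (layer, block) loop, using an ordinary DGL graph in place of sampled blocks. The toy graph, feature sizes, and layer setup are assumptions, not taken from the example:

# Local sketch of the forward loop; the toy graph and all sizes are assumptions.
import dgl
import torch as th
import torch.nn.functional as F
from dgl.nn import SAGEConv

g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))                 # toy 4-node cycle
h = th.randn(g.num_nodes(), 8)                              # fake node features
layers = [SAGEConv(8, 16, "mean"), SAGEConv(16, 4, "mean")]
for i, layer in enumerate(layers):
    h = layer(g, h)          # the full graph stands in for a sampled block
    if i != len(layers) - 1:
        h = F.relu(h)
print(h.shape)  # torch.Size([4, 4])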
@@ -50,41 +68,46 @@ class DistSAGE(nn.Module):
     def inference(self, g, x, batch_size, device):
         """
-        Inference with the GraphSAGE model on full neighbors (i.e. without
-        neighbor sampling).
-
-        g : the entire graph.
-        x : the input of entire node set.
-
-        Distributed layer-wise inference.
+        Distributed layer-wise inference with the GraphSAGE model on full
+        neighbors.
+
+        Parameters
+        ----------
+        g : DistGraph
+            Input Graph for inference.
+        x : DistTensor
+            Node feature data of input graph.
+
+        Returns
+        -------
+        DistTensor
+            Inference results.
         """
-        # During inference with sampling, multi-layer blocks are very
-        # inefficient because lots of computations in the first few layers
-        # are repeated. Therefore, we compute the representation of all nodes
-        # layer by layer. The nodes on each layer are of course splitted in
-        # batches.
-        # TODO: can we standardize this?
+        # Split nodes to each trainer.
         nodes = dgl.distributed.node_split(
             np.arange(g.num_nodes()),
             g.get_partition_book(),
             force_even=True,
         )
-        y = dgl.distributed.DistTensor(
-            (g.num_nodes(), self.n_hidden),
-            th.float32,
-            "h",
-            persistent=True,
-        )
         for i, layer in enumerate(self.layers):
+            # Create DistTensor to save forward results.
             if i == len(self.layers) - 1:
-                y = dgl.distributed.DistTensor(
-                    (g.num_nodes(), self.n_classes),
-                    th.float32,
-                    "h_last",
-                    persistent=True,
-                )
+                out_dim = self.n_classes
+                name = "h_last"
+            else:
+                out_dim = self.n_hidden
+                name = "h"
+            y = dgl.distributed.DistTensor(
+                (g.num_nodes(), out_dim),
+                th.float32,
+                name,
+                persistent=True,
+            )
-            print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")
+            print(f"|V|={g.num_nodes()}, inference batch size: {batch_size}")
+            # `-1` indicates all inbound edges will be included, namely, full
+            # neighbor sampling.
             sampler = dgl.dataloading.NeighborSampler([-1])
             dataloader = dgl.dataloading.DistNodeDataLoader(
                 g,
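To make the per-layer output sizing concrete, a tiny standalone sketch of the if/else introduced above, with made-up values and no DGL dependency:

# Standalone illustration of the per-layer DistTensor sizing; values are made up.
n_hidden, n_classes = 16, 47
layers = ["conv0", "conv1", "conv2"]  # stand-ins for SAGEConv layers
for i, _layer in enumerate(layers):
    if i == len(layers) - 1:
        out_dim, name = n_classes, "h_last"
    else:
        out_dim, name = n_hidden, "h"
    print(f"layer {i}: DistTensor of shape (num_nodes, {out_dim}), name={name!r}")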
@@ -103,17 +126,30 @@ class DistSAGE(nn.Module):
                 if i != len(self.layers) - 1:
                     h = self.activation(h)
                     h = self.dropout(h)
+                # Copy back to CPU as DistTensor requires data reside on CPU.
                 y[output_nodes] = h.cpu()

             x = y
+            # Synchronize trainers.
             g.barrier()
-        return y
+        return x


 def compute_acc(pred, labels):
     """
     Compute the accuracy of prediction given the labels.
+
+    Parameters
+    ----------
+    pred : torch.Tensor
+        Predicted labels.
+    labels : torch.Tensor
+        Ground-truth labels.
+
+    Returns
+    -------
+    float
+        Accuracy.
     """
     labels = labels.long()
     return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
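A quick, self-contained check of the accuracy expression used by compute_acc; the logits and labels below are made up for illustration:

# Self-contained check of the accuracy formula; tensors are made up.
import torch as th

pred = th.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])  # logits for 3 nodes
labels = th.tensor([0, 1, 1])
acc = (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
print(acc.item())  # 2 of 3 predictions match -> ~0.6667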
@@ -121,13 +157,33 @@ def compute_acc(pred, labels):
 def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
     """
-    Evaluate the model on the validation set specified by ``val_nid``.
-    g : The entire graph.
-    inputs : The features of all the nodes.
-    labels : The labels of all the nodes.
-    val_nid : the node Ids for validation.
-    batch_size : Number of nodes to compute at the same time.
-    device : The GPU device to evaluate on.
+    Evaluate the model on the validation and test set.
+
+    Parameters
+    ----------
+    model : DistSAGE
+        The model to be evaluated.
+    g : DistGraph
+        The entire graph.
+    inputs : DistTensor
+        The feature data of all the nodes.
+    labels : DistTensor
+        The labels of all the nodes.
+    val_nid : torch.Tensor
+        The node IDs for validation.
+    test_nid : torch.Tensor
+        The node IDs for test.
+    batch_size : int
+        Batch size for evaluation.
+    device : torch.Device
+        The target device to evaluate on.
+
+    Returns
+    -------
+    float
+        Validation accuracy.
+    float
+        Test accuracy.
     """
     model.eval()
     with th.no_grad():
@@ -139,6 +195,19 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
 def run(args, device, data):
+    """
+    Train and evaluate DistSAGE.
+
+    Parameters
+    ----------
+    args : argparse.Args
+        Arguments for train and evaluate.
+    device : torch.Device
+        Target device for train and evaluate.
+    data : Packed Data
+        Packed data includes train/val/test IDs, feature dimension,
+        number of classes, graph.
+    """
     train_nid, val_nid, test_nid, in_feats, n_classes, g = data
     sampler = dgl.dataloading.NeighborSampler(
         [int(fanout) for fanout in args.fan_out.split(",")]
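The --fan_out argument arrives as a comma-separated string and is parsed into one fanout per GNN layer, as in this standalone sketch (the value "10,25" is an assumed example):

# Standalone sketch of the fan_out parsing; "10,25" is an assumed CLI value.
fan_out = "10,25"
fanouts = [int(fanout) for fanout in fan_out.split(",")]
print(fanouts)  # [10, 25] -> one sampling fanout per GNN layer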
@@ -178,6 +247,7 @@ def run(args, device, data):
     for _ in range(args.num_epochs):
         epoch += 1
         tic = time.time()
+        # Various time statistics.
         sample_time = 0
         forward_time = 0
         backward_time = 0
@@ -185,18 +255,15 @@ def run(args, device, data):
         num_seeds = 0
         num_inputs = 0
         start = time.time()
-        # Loop over the dataloader to sample the computation dependency graph
-        # as a list of blocks.
         step_time = []
         with model.join():
             for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
                 tic_step = time.time()
                 sample_time += tic_step - start
-                batch_inputs, batch_labels = load_subtensor(
-                    g, seeds, input_nodes, "cpu"
-                )
-                batch_labels = batch_labels.long()
+                # Slice feature and label.
+                batch_inputs = g.ndata["features"][input_nodes]
+                batch_labels = g.ndata["labels"][seeds].long()
                 num_seeds += len(blocks[-1].dstdata[dgl.NID])
                 num_inputs += len(blocks[0].srcdata[dgl.NID])
                 # Move to target device.
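With load_subtensor removed, the mini-batch features and labels now come from indexing g.ndata directly. A local analogue with plain torch tensors standing in for the DistTensors (all data made up):

# Local analogue of the new feature/label slicing; tensors are made-up stand-ins
# for the DistTensors held in g.ndata.
import torch as th

features = th.randn(10, 4)          # stand-in for g.ndata["features"]
labels = th.randint(0, 3, (10,))    # stand-in for g.ndata["labels"]
input_nodes = th.tensor([0, 2, 5])  # nodes needed as inputs to the first layer
seeds = th.tensor([2, 5])           # mini-batch seed nodes
batch_inputs = features[input_nodes]
batch_labels = labels[seeds].long()
print(batch_inputs.shape, batch_labels.shape)  # torch.Size([3, 4]) torch.Size([2])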
@@ -227,36 +294,23 @@ def run(args, device, data):
                         if th.cuda.is_available()
                         else 0
                     )
+                    sample_speed = np.mean(iter_tput[-args.log_every :])
+                    mean_step_time = np.mean(step_time[-args.log_every :])
                     print(
-                        "Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | "
-                        "Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU "
-                        "{:.1f} MB | time {:.3f} s".format(
-                            g.rank(),
-                            epoch,
-                            step,
-                            loss.item(),
-                            acc.item(),
-                            np.mean(iter_tput[3:]),
-                            gpu_mem_alloc,
-                            np.mean(step_time[-args.log_every :]),
-                        )
+                        f"Part {g.rank()} | Epoch {epoch:05d} | Step {step:05d}"
+                        f" | Loss {loss.item():.4f} | Train Acc {acc.item():.4f}"
+                        f" | Speed (samples/sec) {sample_speed:.4f}"
+                        f" | GPU {gpu_mem_alloc:.1f} MB | "
+                        f"Mean step time {mean_step_time:.3f} s"
                     )
                 start = time.time()

         toc = time.time()
         print(
-            "Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, "
-            "forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, "
-            "#inputs: {}".format(
-                g.rank(),
-                toc - tic,
-                sample_time,
-                forward_time,
-                backward_time,
-                update_time,
-                num_seeds,
-                num_inputs,
-            )
+            f"Part {g.rank()}, Epoch Time(s): {toc - tic:.4f}, "
+            f"sample+data_copy: {sample_time:.4f}, forward: {forward_time:.4f},"
+            f" backward: {backward_time:.4f}, update: {update_time:.4f}, "
+            f"#seeds: {num_seeds}, #inputs: {num_inputs}"
        )
         epoch_time.append(toc - tic)
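The rewritten log lines lean on f-string format specs instead of str.format; a tiny standalone illustration with made-up values:

# Standalone illustration of the format specs used in the new log lines.
epoch, step, loss_val, acc_val = 3, 42, 0.123456, 0.987654
print(f"Epoch {epoch:05d} | Step {step:05d} | Loss {loss_val:.4f} | Train Acc {acc_val:.4f}")
# -> Epoch 00003 | Step 00042 | Loss 0.1235 | Train Acc 0.9877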
@@ -273,23 +327,27 @@ def run(args, device, data):
         device,
     )
     print(
-        "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
-            g.rank(), val_acc, test_acc, time.time() - start
-        )
+        f"Part {g.rank()}, Val Acc {val_acc:.4f}, "
+        f"Test Acc {test_acc:.4f}, time: {time.time() - start:.4f}"
     )
     return np.mean(epoch_time[-int(args.num_epochs * 0.8) :]), test_acc


 def main(args):
-    print(socket.gethostname(), "Initializing DistDGL.")
+    """
+    Main function.
+    """
+    host_name = socket.gethostname()
+    print(f"{host_name}: Initializing DistDGL.")
     dgl.distributed.initialize(args.ip_config, net_type=args.net_type)
-    print(socket.gethostname(), "Initializing PyTorch process group.")
+    print(f"{host_name}: Initializing PyTorch process group.")
     th.distributed.init_process_group(backend=args.backend)
-    print(socket.gethostname(), "Initializing DistGraph.")
+    print(f"{host_name}: Initializing DistGraph.")
     g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
-    print(socket.gethostname(), "rank:", g.rank())
+    print(f"Rank of {host_name}: {g.rank()}")
+
+    # Split train/val/test IDs for each trainer.
     pb = g.get_partition_book()
     if "trainer_id" in g.ndata:
         train_nid = dgl.distributed.node_split(
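dgl.distributed.node_split hands each trainer a disjoint share of the masked nodes. A rough local analogue of that idea with NumPy (the mask and trainer count are made up; real DistDGL consults the partition book rather than splitting evenly like this):

# Rough local analogue of splitting masked node IDs across trainers; data made up.
import numpy as np

train_mask = np.array([1, 0, 1, 1, 0, 1, 0, 1], dtype=bool)
num_trainers = 2
nids = np.nonzero(train_mask)[0]             # IDs of nodes with the mask set
shares = np.array_split(nids, num_trainers)  # roughly even share per trainer
for rank, share in enumerate(shares):
    print(f"trainer {rank}: {share}")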
@@ -321,17 +379,13 @@ def main(args):
             g.ndata["test_mask"], pb, force_even=True
         )
     local_nid = pb.partid2nids(pb.partid).detach().numpy()
+    num_train_local = len(np.intersect1d(train_nid.numpy(), local_nid))
+    num_val_local = len(np.intersect1d(val_nid.numpy(), local_nid))
+    num_test_local = len(np.intersect1d(test_nid.numpy(), local_nid))
     print(
-        "part {}, train: {} (local: {}), val: {} (local: {}), test: {} "
-        "(local: {})".format(
-            g.rank(),
-            len(train_nid),
-            len(np.intersect1d(train_nid.numpy(), local_nid)),
-            len(val_nid),
-            len(np.intersect1d(val_nid.numpy(), local_nid)),
-            len(test_nid),
-            len(np.intersect1d(test_nid.numpy(), local_nid)),
-        )
+        f"part {g.rank()}, train: {len(train_nid)} (local: {num_train_local}), "
+        f"val: {len(val_nid)} (local: {num_val_local}), "
+        f"test: {len(test_nid)} (local: {num_test_local})"
     )
     del local_nid
     if args.num_gpus == 0:
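The new *_local counters are just set intersections between this trainer's split IDs and the IDs stored on the local partition; a standalone sketch with made-up arrays:

# Standalone sketch of the local-ID bookkeeping; the ID arrays are made up.
import numpy as np

train_nid = np.array([1, 4, 7, 9])     # this trainer's share of train IDs
local_nid = np.array([0, 1, 2, 3, 4])  # node IDs stored on the local partition
num_train_local = len(np.intersect1d(train_nid, local_nid))
print(num_train_local)  # 2 -> nodes 1 and 4 live on the local partition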