Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
78fa316a
Unverified
Commit
78fa316a
authored
Jan 18, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Jan 18, 2024
Browse files
[GraphBolt][CUDA] Modify multiGPU example to use GPU sampling. (#6961)
parent
f6db850d
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
14 deletions
+25
-14
examples/multigpu/graphbolt/node_classification.py
examples/multigpu/graphbolt/node_classification.py
+25
-14
No files found.
examples/multigpu/graphbolt/node_classification.py
View file @
78fa316a
...
@@ -126,9 +126,6 @@ def create_dataloader(
...
@@ -126,9 +126,6 @@ def create_dataloader(
shuffle
=
shuffle
,
shuffle
=
shuffle
,
drop_uneven_inputs
=
drop_uneven_inputs
,
drop_uneven_inputs
=
drop_uneven_inputs
,
)
)
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
args
.
fanout
)
datapipe
=
datapipe
.
fetch_feature
(
features
,
node_feature_keys
=
[
"feat"
])
############################################################################
############################################################################
# [Note]:
# [Note]:
# datapipe.copy_to() / gb.CopyTo()
# datapipe.copy_to() / gb.CopyTo()
...
@@ -137,8 +134,14 @@ def create_dataloader(
...
@@ -137,8 +134,14 @@ def create_dataloader(
# [Output]:
# [Output]:
# A CopyTo object copying data in the datapipe to a specified device.\
# A CopyTo object copying data in the datapipe to a specified device.\
############################################################################
############################################################################
if
not
args
.
cpu_sampling
:
datapipe
=
datapipe
.
copy_to
(
device
,
extra_attrs
=
[
"seed_nodes"
])
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
args
.
fanout
)
datapipe
=
datapipe
.
fetch_feature
(
features
,
node_feature_keys
=
[
"feat"
])
if
args
.
cpu_sampling
:
datapipe
=
datapipe
.
copy_to
(
device
)
datapipe
=
datapipe
.
copy_to
(
device
)
dataloader
=
gb
.
DataLoader
(
datapipe
,
num_workers
=
args
.
num_workers
)
dataloader
=
gb
.
DataLoader
(
datapipe
,
args
.
num_workers
)
# Return the fully-initialized DataLoader object.
# Return the fully-initialized DataLoader object.
return
dataloader
return
dataloader
...
@@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset):
rank
=
rank
,
rank
=
rank
,
)
)
graph
=
dataset
.
graph
# Pin the graph and features to enable GPU access.
features
=
dataset
.
feature
if
not
args
.
cpu_sampling
:
dataset
.
graph
.
pin_memory_
()
dataset
.
feature
.
pin_memory_
()
train_set
=
dataset
.
tasks
[
0
].
train_set
train_set
=
dataset
.
tasks
[
0
].
train_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
valid_set
=
dataset
.
tasks
[
0
].
validation_set
test_set
=
dataset
.
tasks
[
0
].
test_set
test_set
=
dataset
.
tasks
[
0
].
test_set
args
.
fanout
=
list
(
map
(
int
,
args
.
fanout
.
split
(
","
)))
args
.
fanout
=
list
(
map
(
int
,
args
.
fanout
.
split
(
","
)))
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
num_classes
=
dataset
.
tasks
[
0
].
metadata
[
"num_classes"
]
in_size
=
feature
s
.
size
(
"node"
,
None
,
"feat"
)[
0
]
in_size
=
dataset
.
feature
.
size
(
"node"
,
None
,
"feat"
)[
0
]
hidden_size
=
256
hidden_size
=
256
out_size
=
num_classes
out_size
=
num_classes
...
@@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset):
# Create data loaders.
# Create data loaders.
train_dataloader
=
create_dataloader
(
train_dataloader
=
create_dataloader
(
args
,
args
,
graph
,
dataset
.
graph
,
feature
s
,
dataset
.
feature
,
train_set
,
train_set
,
device
,
device
,
drop_last
=
False
,
drop_last
=
False
,
...
@@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset):
)
)
valid_dataloader
=
create_dataloader
(
valid_dataloader
=
create_dataloader
(
args
,
args
,
graph
,
dataset
.
graph
,
feature
s
,
dataset
.
feature
,
valid_set
,
valid_set
,
device
,
device
,
drop_last
=
False
,
drop_last
=
False
,
...
@@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset):
...
@@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset):
)
)
test_dataloader
=
create_dataloader
(
test_dataloader
=
create_dataloader
(
args
,
args
,
graph
,
dataset
.
graph
,
feature
s
,
dataset
.
feature
,
test_set
,
test_set
,
device
,
device
,
drop_last
=
False
,
drop_last
=
False
,
...
@@ -387,6 +393,11 @@ def parse_args():
...
@@ -387,6 +393,11 @@ def parse_args():
parser
.
add_argument
(
parser
.
add_argument
(
"--num-workers"
,
type
=
int
,
default
=
0
,
help
=
"The number of processes."
"--num-workers"
,
type
=
int
,
default
=
0
,
help
=
"The number of processes."
)
)
parser
.
add_argument
(
"--cpu-sampling"
,
action
=
"store_true"
,
help
=
"Disables GPU sampling and utilizes the CPU for dataloading."
,
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment