Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
6178897d
Unverified
Commit
6178897d
authored
Dec 26, 2023
by
yxy235
Committed by
GitHub
Dec 26, 2023
Browse files
[GraphBolt] Remove opt `output_cscformat` from examples. (#6834)
parent
3645e493
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
32 deletions
+3
-32
examples/sampling/graphbolt/link_prediction.py
examples/sampling/graphbolt/link_prediction.py
+1
-11
examples/sampling/graphbolt/node_classification.py
examples/sampling/graphbolt/node_classification.py
+2
-21
No files found.
examples/sampling/graphbolt/link_prediction.py
View file @
6178897d
...
@@ -175,11 +175,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
...
@@ -175,11 +175,7 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Role]:
# [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
############################################################################
datapipe
=
datapipe
.
sample_neighbor
(
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
args
.
fanout
)
graph
,
args
.
fanout
,
output_cscformat
=
(
args
.
output_cscformat
==
"True"
),
)
############################################################################
############################################################################
# [Input]:
# [Input]:
...
@@ -375,12 +371,6 @@ def parse_args():
...
@@ -375,12 +371,6 @@ def parse_args():
choices
=
[
"cpu"
,
"cuda"
],
choices
=
[
"cpu"
,
"cuda"
],
help
=
"Train device: 'cpu' for CPU, 'cuda' for GPU."
,
help
=
"Train device: 'cpu' for CPU, 'cuda' for GPU."
,
)
)
parser
.
add_argument
(
"--output_cscformat"
,
default
=
"False"
,
choices
=
[
"False"
,
"True"
],
help
=
"Output type of SampledSubgraph. True for csc_formats, False for node_pairs."
,
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
...
examples/sampling/graphbolt/node_classification.py
View file @
6178897d
...
@@ -50,15 +50,7 @@ from tqdm import tqdm
...
@@ -50,15 +50,7 @@ from tqdm import tqdm
def
create_dataloader
(
def
create_dataloader
(
graph
,
graph
,
features
,
itemset
,
batch_size
,
fanout
,
device
,
num_workers
,
job
features
,
itemset
,
batch_size
,
fanout
,
device
,
num_workers
,
job
,
output_cscformat
,
):
):
"""
"""
[HIGHLIGHT]
[HIGHLIGHT]
...
@@ -113,9 +105,7 @@ def create_dataloader(
...
@@ -113,9 +105,7 @@ def create_dataloader(
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
############################################################################
datapipe
=
datapipe
.
sample_neighbor
(
datapipe
=
datapipe
.
sample_neighbor
(
graph
,
graph
,
fanout
if
job
!=
"infer"
else
[
-
1
]
fanout
if
job
!=
"infer"
else
[
-
1
],
output_cscformat
=
(
output_cscformat
==
"True"
),
)
)
############################################################################
############################################################################
...
@@ -240,7 +230,6 @@ def layerwise_infer(
...
@@ -240,7 +230,6 @@ def layerwise_infer(
device
=
args
.
device
,
device
=
args
.
device
,
num_workers
=
args
.
num_workers
,
num_workers
=
args
.
num_workers
,
job
=
"infer"
,
job
=
"infer"
,
output_cscformat
=
args
.
output_cscformat
,
)
)
pred
=
model
.
inference
(
graph
,
features
,
dataloader
,
args
.
device
)
pred
=
model
.
inference
(
graph
,
features
,
dataloader
,
args
.
device
)
pred
=
pred
[
test_set
.
_items
[
0
]]
pred
=
pred
[
test_set
.
_items
[
0
]]
...
@@ -268,7 +257,6 @@ def evaluate(args, model, graph, features, itemset, num_classes):
...
@@ -268,7 +257,6 @@ def evaluate(args, model, graph, features, itemset, num_classes):
device
=
args
.
device
,
device
=
args
.
device
,
num_workers
=
args
.
num_workers
,
num_workers
=
args
.
num_workers
,
job
=
"evaluate"
,
job
=
"evaluate"
,
output_cscformat
=
args
.
output_cscformat
,
)
)
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
for
step
,
data
in
tqdm
(
enumerate
(
dataloader
)):
...
@@ -295,7 +283,6 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
...
@@ -295,7 +283,6 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
device
=
args
.
device
,
device
=
args
.
device
,
num_workers
=
args
.
num_workers
,
num_workers
=
args
.
num_workers
,
job
=
"train"
,
job
=
"train"
,
output_cscformat
=
args
.
output_cscformat
,
)
)
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
epochs
):
...
@@ -367,12 +354,6 @@ def parse_args():
...
@@ -367,12 +354,6 @@ def parse_args():
choices
=
[
"cpu"
,
"cuda"
],
choices
=
[
"cpu"
,
"cuda"
],
help
=
"Train device: 'cpu' for CPU, 'cuda' for GPU."
,
help
=
"Train device: 'cpu' for CPU, 'cuda' for GPU."
,
)
)
parser
.
add_argument
(
"--output_cscformat"
,
default
=
"False"
,
choices
=
[
"False"
,
"True"
],
help
=
"Output type of SampledSubgraph. True for csc_formats, False for node_pairs."
,
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment