Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f5b04558
Unverified
Commit
f5b04558
authored
Oct 16, 2023
by
peizhou001
Committed by
GitHub
Oct 16, 2023
Browse files
[Graphbolt] Fix fanout order (#6447)
parent
c8ec9ce3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
5 deletions
+12
-5
examples/sampling/graphbolt/link_prediction.py
examples/sampling/graphbolt/link_prediction.py
+7
-4
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+5
-1
No files found.
examples/sampling/graphbolt/link_prediction.py
View file @
f5b04558
...
@@ -301,9 +301,9 @@ def train(args, graph, features, train_set, valid_set, model):
...
@@ -301,9 +301,9 @@ def train(args, graph, features, train_set, valid_set, model):
break
break
# Evaluate the model.
# Evaluate the model.
print
(
"Validation"
)
#
print("Validation")
valid_mrr
=
evaluate
(
args
,
graph
,
features
,
valid_set
,
model
)
#
valid_mrr = evaluate(args, graph, features, valid_set, model)
print
(
f
"Valid MRR
{
valid_mrr
.
item
():.
4
f
}
"
)
#
print(f"Valid MRR {valid_mrr.item():.4f}")
def
parse_args
():
def
parse_args
():
...
@@ -354,8 +354,11 @@ def main(args):
...
@@ -354,8 +354,11 @@ def main(args):
# Model training.
# Model training.
print
(
"Training..."
)
print
(
"Training..."
)
train
(
args
,
graph
,
features
,
train_set
,
valid_set
,
model
)
import
time
s
=
time
.
perf_counter
()
train
(
args
,
graph
,
features
,
train_set
,
valid_set
,
model
)
print
(
f
"
{
time
.
perf_counter
()
-
s
}
seconds elpased. "
)
# Test the model.
# Test the model.
print
(
"Testing..."
)
print
(
"Testing..."
)
test_set
=
dataset
.
tasks
[
0
].
test_set
test_set
=
dataset
.
tasks
[
0
].
test_set
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
f5b04558
...
@@ -34,6 +34,10 @@ class NeighborSampler(SubgraphSampler):
...
@@ -34,6 +34,10 @@ class NeighborSampler(SubgraphSampler):
The number of edges to be sampled for each node with or without
The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly
considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted.
signifies the layer of sampling being conducted.
Note: The fanout order is from the outermost layer to innermost layer.
For example, the fanout '[15, 10, 5]' means that 15 to the outermost
layer, 10 to the intermediate layer and 5 corresponds to the innermost
layer.
replace: bool
replace: bool
Boolean indicating whether the sample is preformed with or
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
without replacement. If True, a value can be selected multiple
...
@@ -90,7 +94,7 @@ class NeighborSampler(SubgraphSampler):
...
@@ -90,7 +94,7 @@ class NeighborSampler(SubgraphSampler):
for
fanout
in
fanouts
:
for
fanout
in
fanouts
:
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
self
.
fanouts
.
append
(
fanout
)
self
.
fanouts
.
insert
(
0
,
fanout
)
self
.
replace
=
replace
self
.
replace
=
replace
self
.
prob_name
=
prob_name
self
.
prob_name
=
prob_name
self
.
deduplicate
=
deduplicate
self
.
deduplicate
=
deduplicate
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment