Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
f5b04558
"vscode:/vscode.git/clone" did not exist on "015acfd2d868852d903ea03824ce7b308a556fcf"
Unverified
Commit
f5b04558
authored
Oct 16, 2023
by
peizhou001
Committed by
GitHub
Oct 16, 2023
Browse files
[Graphbolt] Fix fanout order (#6447)
parent
c8ec9ce3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
5 deletions
+12
-5
examples/sampling/graphbolt/link_prediction.py
examples/sampling/graphbolt/link_prediction.py
+7
-4
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+5
-1
No files found.
examples/sampling/graphbolt/link_prediction.py
View file @
f5b04558
...
...
@@ -301,9 +301,9 @@ def train(args, graph, features, train_set, valid_set, model):
break
# Evaluate the model.
print
(
"Validation"
)
valid_mrr
=
evaluate
(
args
,
graph
,
features
,
valid_set
,
model
)
print
(
f
"Valid MRR
{
valid_mrr
.
item
():.
4
f
}
"
)
#
print("Validation")
#
valid_mrr = evaluate(args, graph, features, valid_set, model)
#
print(f"Valid MRR {valid_mrr.item():.4f}")
def
parse_args
():
...
...
@@ -354,8 +354,11 @@ def main(args):
# Model training.
print
(
"Training..."
)
train
(
args
,
graph
,
features
,
train_set
,
valid_set
,
model
)
import
time
s
=
time
.
perf_counter
()
train
(
args
,
graph
,
features
,
train_set
,
valid_set
,
model
)
print
(
f
"
{
time
.
perf_counter
()
-
s
}
seconds elpased. "
)
# Test the model.
print
(
"Testing..."
)
test_set
=
dataset
.
tasks
[
0
].
test_set
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
f5b04558
...
...
@@ -34,6 +34,10 @@ class NeighborSampler(SubgraphSampler):
The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted.
Note: The fanout order is from the outermost layer to innermost layer.
For example, the fanout '[15, 10, 5]' means that 15 to the outermost
layer, 10 to the intermediate layer and 5 corresponds to the innermost
layer.
replace: bool
Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple
...
...
@@ -90,7 +94,7 @@ class NeighborSampler(SubgraphSampler):
for
fanout
in
fanouts
:
if
not
isinstance
(
fanout
,
torch
.
Tensor
):
fanout
=
torch
.
LongTensor
([
int
(
fanout
)])
self
.
fanouts
.
append
(
fanout
)
self
.
fanouts
.
insert
(
0
,
fanout
)
self
.
replace
=
replace
self
.
prob_name
=
prob_name
self
.
deduplicate
=
deduplicate
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment