Unverified Commit f5b04558 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Fix fanout order (#6447)

parent c8ec9ce3
...@@ -301,9 +301,9 @@ def train(args, graph, features, train_set, valid_set, model): ...@@ -301,9 +301,9 @@ def train(args, graph, features, train_set, valid_set, model):
break break
# Evaluate the model. # Evaluate the model.
print("Validation") # print("Validation")
valid_mrr = evaluate(args, graph, features, valid_set, model) # valid_mrr = evaluate(args, graph, features, valid_set, model)
print(f"Valid MRR {valid_mrr.item():.4f}") # print(f"Valid MRR {valid_mrr.item():.4f}")
def parse_args(): def parse_args():
...@@ -354,8 +354,11 @@ def main(args): ...@@ -354,8 +354,11 @@ def main(args):
# Model training. # Model training.
print("Training...") print("Training...")
train(args, graph, features, train_set, valid_set, model) import time
s = time.perf_counter()
train(args, graph, features, train_set, valid_set, model)
print(f"{time.perf_counter() - s} seconds elpased. ")
# Test the model. # Test the model.
print("Testing...") print("Testing...")
test_set = dataset.tasks[0].test_set test_set = dataset.tasks[0].test_set
......
...@@ -34,6 +34,10 @@ class NeighborSampler(SubgraphSampler): ...@@ -34,6 +34,10 @@ class NeighborSampler(SubgraphSampler):
The number of edges to be sampled for each node with or without The number of edges to be sampled for each node with or without
considering edge types. The length of this parameter implicitly considering edge types. The length of this parameter implicitly
signifies the layer of sampling being conducted. signifies the layer of sampling being conducted.
Note: The fanout order is from the outermost layer to innermost layer.
For example, the fanout '[15, 10, 5]' means that 15 to the outermost
layer, 10 to the intermediate layer and 5 corresponds to the innermost
layer.
replace: bool replace: bool
Boolean indicating whether the sample is preformed with or Boolean indicating whether the sample is preformed with or
without replacement. If True, a value can be selected multiple without replacement. If True, a value can be selected multiple
...@@ -90,7 +94,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -90,7 +94,7 @@ class NeighborSampler(SubgraphSampler):
for fanout in fanouts: for fanout in fanouts:
if not isinstance(fanout, torch.Tensor): if not isinstance(fanout, torch.Tensor):
fanout = torch.LongTensor([int(fanout)]) fanout = torch.LongTensor([int(fanout)])
self.fanouts.append(fanout) self.fanouts.insert(0, fanout)
self.replace = replace self.replace = replace
self.prob_name = prob_name self.prob_name = prob_name
self.deduplicate = deduplicate self.deduplicate = deduplicate
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment