Unverified Commit 3885da25 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `_init_all_nodes_set` to use `seeds`. (#7243)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 2f585940
......@@ -101,7 +101,7 @@ class SAGE(nn.Module):
if not is_last_layer:
hidden_x = F.relu(hidden_x)
# By design, our seed nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
buffer_device, non_blocking=True
)
if not is_last_layer:
......
......@@ -218,7 +218,7 @@ class SAGE(nn.Module):
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
# By design, our output nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
buffer_device
)
if not is_last_layer:
......
......@@ -138,7 +138,7 @@ class GraphSAGE(torch.nn.Module):
if not is_last_layer:
hidden_x = F.relu(hidden_x)
# By design, our output nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
buffer_device
)
if not is_last_layer:
......
......@@ -900,13 +900,13 @@ class OnDiskDataset(Dataset):
if isinstance(num_nodes, int):
return ItemSet(
torch.tensor(num_nodes, dtype=dtype),
names="seed_nodes",
names="seeds",
)
else:
data = {
node_type: ItemSet(
torch.tensor(num_node, dtype=dtype),
names="seed_nodes",
names="seeds",
)
for node_type, num_node in num_nodes.items()
}
......
......@@ -198,21 +198,21 @@ def random_homo_graphbolt_graph(
train_set:
- type: null
data:
- name: node_pairs
- name: seeds
format: numpy
in_memory: true
path: {train_path}
validation_set:
- type: null
data:
- name: node_pairs
- name: seeds
format: numpy
in_memory: true
path: {validation_path}
test_set:
- type: null
data:
- name: node_pairs
- name: seeds
format: numpy
in_memory: true
path: {test_path}
......@@ -349,21 +349,21 @@ def generate_raw_data_for_hetero_dataset(
train_set:
- type: user
data:
- name: seed_nodes
- name: seeds
format: numpy
in_memory: true
path: {train_path}
validation_set:
- type: user
data:
- name: seed_nodes
- name: seeds
format: numpy
in_memory: true
path: {validation_path}
test_set:
- type: user
data:
- name: seed_nodes
- name: seeds
format: numpy
in_memory: true
path: {test_path}
......
......@@ -2326,7 +2326,7 @@ def test_OnDiskDataset_all_nodes_set_homo():
dataset = write_yaml_and_load_dataset(yaml_content, test_dir)
all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSet)
assert all_nodes_set.names == ("seed_nodes",)
assert all_nodes_set.names == ("seeds",)
for i, item in enumerate(all_nodes_set):
assert i == item
......@@ -2365,7 +2365,7 @@ def test_OnDiskDataset_all_nodes_set_hetero():
dataset = write_yaml_and_load_dataset(yaml_content, test_dir)
all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSetDict)
assert all_nodes_set.names == ("seed_nodes",)
assert all_nodes_set.names == ("seeds",)
for i, item in enumerate(all_nodes_set):
assert len(item) == 1
assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment