Unverified Commit 3885da25 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `_init_all_nodes_set` to use `seeds`. (#7243)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 2f585940
...@@ -101,7 +101,7 @@ class SAGE(nn.Module): ...@@ -101,7 +101,7 @@ class SAGE(nn.Module):
if not is_last_layer: if not is_last_layer:
hidden_x = F.relu(hidden_x) hidden_x = F.relu(hidden_x)
# By design, our seed nodes are contiguous. # By design, our seed nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to( y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
buffer_device, non_blocking=True buffer_device, non_blocking=True
) )
if not is_last_layer: if not is_last_layer:
......
...@@ -218,7 +218,7 @@ class SAGE(nn.Module): ...@@ -218,7 +218,7 @@ class SAGE(nn.Module):
hidden_x = F.relu(hidden_x) hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x) hidden_x = self.dropout(hidden_x)
# By design, our output nodes are contiguous. # By design, our output nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to( y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
buffer_device buffer_device
) )
if not is_last_layer: if not is_last_layer:
......
...@@ -138,7 +138,7 @@ class GraphSAGE(torch.nn.Module): ...@@ -138,7 +138,7 @@ class GraphSAGE(torch.nn.Module):
if not is_last_layer: if not is_last_layer:
hidden_x = F.relu(hidden_x) hidden_x = F.relu(hidden_x)
# By design, our output nodes are contiguous. # By design, our output nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to( y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
buffer_device buffer_device
) )
if not is_last_layer: if not is_last_layer:
......
...@@ -900,13 +900,13 @@ class OnDiskDataset(Dataset): ...@@ -900,13 +900,13 @@ class OnDiskDataset(Dataset):
if isinstance(num_nodes, int): if isinstance(num_nodes, int):
return ItemSet( return ItemSet(
torch.tensor(num_nodes, dtype=dtype), torch.tensor(num_nodes, dtype=dtype),
names="seed_nodes", names="seeds",
) )
else: else:
data = { data = {
node_type: ItemSet( node_type: ItemSet(
torch.tensor(num_node, dtype=dtype), torch.tensor(num_node, dtype=dtype),
names="seed_nodes", names="seeds",
) )
for node_type, num_node in num_nodes.items() for node_type, num_node in num_nodes.items()
} }
......
...@@ -198,21 +198,21 @@ def random_homo_graphbolt_graph( ...@@ -198,21 +198,21 @@ def random_homo_graphbolt_graph(
train_set: train_set:
- type: null - type: null
data: data:
- name: node_pairs - name: seeds
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_path} path: {train_path}
validation_set: validation_set:
- type: null - type: null
data: data:
- name: node_pairs - name: seeds
format: numpy format: numpy
in_memory: true in_memory: true
path: {validation_path} path: {validation_path}
test_set: test_set:
- type: null - type: null
data: data:
- name: node_pairs - name: seeds
format: numpy format: numpy
in_memory: true in_memory: true
path: {test_path} path: {test_path}
...@@ -349,21 +349,21 @@ def generate_raw_data_for_hetero_dataset( ...@@ -349,21 +349,21 @@ def generate_raw_data_for_hetero_dataset(
train_set: train_set:
- type: user - type: user
data: data:
- name: seed_nodes - name: seeds
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_path} path: {train_path}
validation_set: validation_set:
- type: user - type: user
data: data:
- name: seed_nodes - name: seeds
format: numpy format: numpy
in_memory: true in_memory: true
path: {validation_path} path: {validation_path}
test_set: test_set:
- type: user - type: user
data: data:
- name: seed_nodes - name: seeds
format: numpy format: numpy
in_memory: true in_memory: true
path: {test_path} path: {test_path}
......
...@@ -2326,7 +2326,7 @@ def test_OnDiskDataset_all_nodes_set_homo(): ...@@ -2326,7 +2326,7 @@ def test_OnDiskDataset_all_nodes_set_homo():
dataset = write_yaml_and_load_dataset(yaml_content, test_dir) dataset = write_yaml_and_load_dataset(yaml_content, test_dir)
all_nodes_set = dataset.all_nodes_set all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSet) assert isinstance(all_nodes_set, gb.ItemSet)
assert all_nodes_set.names == ("seed_nodes",) assert all_nodes_set.names == ("seeds",)
for i, item in enumerate(all_nodes_set): for i, item in enumerate(all_nodes_set):
assert i == item assert i == item
...@@ -2365,7 +2365,7 @@ def test_OnDiskDataset_all_nodes_set_hetero(): ...@@ -2365,7 +2365,7 @@ def test_OnDiskDataset_all_nodes_set_hetero():
dataset = write_yaml_and_load_dataset(yaml_content, test_dir) dataset = write_yaml_and_load_dataset(yaml_content, test_dir)
all_nodes_set = dataset.all_nodes_set all_nodes_set = dataset.all_nodes_set
assert isinstance(all_nodes_set, gb.ItemSetDict) assert isinstance(all_nodes_set, gb.ItemSetDict)
assert all_nodes_set.names == ("seed_nodes",) assert all_nodes_set.names == ("seeds",)
for i, item in enumerate(all_nodes_set): for i, item in enumerate(all_nodes_set):
assert len(item) == 1 assert len(item) == 1
assert isinstance(item, dict) assert isinstance(item, dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment