Commit c427300e authored by rusty1s's avatar rusty1s
Browse files

clean up

parent d8d31882
...@@ -21,7 +21,8 @@ from torch_geometric_autoscale import ScalableGNN ...@@ -21,7 +21,8 @@ from torch_geometric_autoscale import ScalableGNN
class GNN(ScalableGNN): class GNN(ScalableGNN):
def __init__(self, num_nodes, in_channels, hidden_channels, out_channels, num_layers): def __init__(self, num_nodes, in_channels, hidden_channels, out_channels, num_layers):
super(GNN, self).__init__(num_nodes, hidden_channels, num_layers) super(GNN, self).__init__(num_nodes, hidden_channels, num_layers,
pool_size=2, buffer_size=5000)
self.convs = ModuleList() self.convs = ModuleList()
self.convs.append(GCNConv(in_channels, hidden_channels)) self.convs.append(GCNConv(in_channels, hidden_channels))
...@@ -29,11 +30,20 @@ class GNN(ScalableGNN): ...@@ -29,11 +30,20 @@ class GNN(ScalableGNN):
self.convs.append(GCNConv(hidden_channels, hidden_channels)) self.convs.append(GCNConv(hidden_channels, hidden_channels))
self.convs.append(GCNConv(hidden_channels, out_channels)) self.convs.append(GCNConv(hidden_channels, out_channels))
def forward(self, x, adj_t, batch_size, n_id): def forward(self, x, adj_t, *args):
for conv, history in zip(self.convs[:-1], self.histories): for conv, history in zip(self.convs[:-1], self.histories):
x = conv(x, adj_t).relu_() x = conv(x, adj_t).relu_()
x = self.push_and_pull(history, x, batch_size, n_id) x = self.push_and_pull(history, x, *args)
return self.convs[-1](x, adj_t) return self.convs[-1](x, adj_t)
perm, ptr = metis(data.adj_t, num_parts=40, log=True)
data = permute(data, perm, log=True)
loader = SubgraphLoader(data, ptr, batch_size=10, shuffle=True)
for batch, *args in loader:
out = model(batch.x, batch.adj_t, *args)
``` ```
## Requirements ## Requirements
......
...@@ -110,24 +110,25 @@ def get_sbm(root: str, name: str) -> Tuple[Data, int, int]: ...@@ -110,24 +110,25 @@ def get_sbm(root: str, name: str) -> Tuple[Data, int, int]:
def get_data(root: str, name: str) -> Tuple[Data, int, int]: def get_data(root: str, name: str) -> Tuple[Data, int, int]:
if name.lower() in ['cora', 'citeseer', 'pubmed']: if name.lower() in ['cora', 'citeseer', 'pubmed']:
return get_planetoid(root, name) return get_planetoid(root, name)
if name.lower() in ['coauthorcs', 'coauthorphysics']: elif name.lower() in ['coauthorcs', 'coauthorphysics']:
return get_coauthor(root, name[8:]) return get_coauthor(root, name[8:])
if name.lower() in ['amazoncomputers', 'amazonphoto']: elif name.lower() in ['amazoncomputers', 'amazonphoto']:
return get_amazon(root, name[6:]) return get_amazon(root, name[6:])
if name.lower() == 'wikics': elif name.lower() == 'wikics':
return get_wikics(root) return get_wikics(root)
if name.lower() in ['cluster', 'pattern']: elif name.lower() in ['cluster', 'pattern']:
return get_sbm(root, name) return get_sbm(root, name)
if name.lower() == 'reddit': elif name.lower() == 'reddit':
return get_reddit(root) return get_reddit(root)
if name.lower() == 'ppi': elif name.lower() == 'ppi':
return get_ppi(root) return get_ppi(root)
if name.lower() == 'flickr': elif name.lower() == 'flickr':
return get_flickr(root) return get_flickr(root)
if name.lower() == 'yelp': elif name.lower() == 'yelp':
return get_yelp(root) return get_yelp(root)
if name.lower() in ['ogbn-arxiv', 'arxiv']: elif name.lower() in ['ogbn-arxiv', 'arxiv']:
return get_arxiv(root) return get_arxiv(root)
if name.lower() in ['ogbn-products', 'products']: elif name.lower() in ['ogbn-products', 'products']:
return get_products(root) return get_products(root)
else:
raise NotImplementedError raise NotImplementedError
...@@ -46,9 +46,9 @@ def permute(data: Union[Data, SparseTensor], perm: Tensor, ...@@ -46,9 +46,9 @@ def permute(data: Union[Data, SparseTensor], perm: Tensor,
for key, item in data: for key, item in data:
if isinstance(item, Tensor) and item.size(0) == data.num_nodes: if isinstance(item, Tensor) and item.size(0) == data.num_nodes:
data[key] = item[perm] data[key] = item[perm]
if isinstance(item, Tensor) and item.size(0) == data.num_edges: elif isinstance(item, Tensor) and item.size(0) == data.num_edges:
raise NotImplementedError raise NotImplementedError
if isinstance(item, SparseTensor): elif isinstance(item, SparseTensor):
data[key] = permute(item, perm, log=False) data[key] = permute(item, perm, log=False)
else: else:
data = data.permute(perm) data = data.permute(perm)
......
...@@ -18,7 +18,7 @@ class ScalableGNN(torch.nn.Module): ...@@ -18,7 +18,7 @@ class ScalableGNN(torch.nn.Module):
self.num_nodes = num_nodes self.num_nodes = num_nodes
self.hidden_channels = hidden_channels self.hidden_channels = hidden_channels
self.num_layers = num_layers self.num_layers = num_layers
self.pool_size = num_layers if pool_size is None else pool_size self.pool_size = num_layers - 1 if pool_size is None else pool_size
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.histories = torch.nn.ModuleList([ self.histories = torch.nn.ModuleList([
...@@ -59,13 +59,17 @@ class ScalableGNN(torch.nn.Module): ...@@ -59,13 +59,17 @@ class ScalableGNN(torch.nn.Module):
for history in self.histories: for history in self.histories:
history.reset_parameters() history.reset_parameters()
def __call__(self, x: Optional[Tensor] = None, def __call__(
self,
x: Optional[Tensor] = None,
adj_t: Optional[SparseTensor] = None, adj_t: Optional[SparseTensor] = None,
batch_size: Optional[int] = None, batch_size: Optional[int] = None,
n_id: Optional[Tensor] = None, n_id: Optional[Tensor] = None,
offset: Optional[Tensor] = None, offset: Optional[Tensor] = None,
count: Optional[Tensor] = None, loader=None, count: Optional[Tensor] = None,
**kwargs) -> Tensor: loader=None,
**kwargs,
) -> Tensor:
if loader is not None: if loader is not None:
return self.mini_inference(loader) return self.mini_inference(loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment