Commit b8ee6962 authored by rusty1s

clean up loader

parent dc5f7414
-from typing import Optional, Union, Tuple, NamedTuple, List
+from typing import NamedTuple, List, Tuple
 import time
...
@@ -12,11 +12,11 @@ relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
 class SubData(NamedTuple):
-    data: Union[Data, SparseTensor]
+    data: Data
     batch_size: int
     n_id: Tensor
-    offset: Optional[Tensor]
-    count: Optional[Tensor]

     def to(self, *args, **kwargs):
         return SubData(self.data.to(*args, **kwargs), self.batch_size,
...
@@ -24,138 +24,101 @@ class SubData(NamedTuple):
 class SubgraphLoader(DataLoader):
-    r"""A simple subgraph loader that, given a randomly sampled or
-    pre-partioned batch of nodes, returns the subgraph of this batch
-    (including its 1-hop neighbors)."""
-    def __init__(
-        self,
-        data: Union[Data, SparseTensor],
-        ptr: Optional[Tensor] = None,
-        batch_size: int = 1,
-        bipartite: bool = True,
-        log: bool = True,
-        **kwargs,
-    ):
-        self.__data__ = None if isinstance(data, SparseTensor) else data
-        self.__adj_t__ = data if isinstance(data, SparseTensor) else data.adj_t
-        self.__N__ = self.__adj_t__.size(1)
-        self.__E__ = self.__adj_t__.nnz()
-        self.__ptr__ = ptr
-        self.__bipartite__ = bipartite
-
-        if ptr is not None:
-            n_id = torch.arange(self.__N__)
-            batches = n_id.split((ptr[1:] - ptr[:-1]).tolist())
-            batches = [(i, batches[i]) for i in range(len(batches))]
-
-            if batch_size > 1:
-                super(SubgraphLoader,
-                      self).__init__(batches,
-                                     collate_fn=self.sample_partitions,
-                                     batch_size=batch_size, **kwargs)
-            else:
-                if log:
-                    t = time.perf_counter()
-                    print('Pre-processing subgraphs...', end=' ', flush=True)
-                data_list = [
-                    data for data in DataLoader(
-                        batches, collate_fn=self.sample_partitions,
-                        batch_size=batch_size, **kwargs)
-                ]
-                if log:
-                    print(f'Done! [{time.perf_counter() - t:.2f}s]')
-                super(SubgraphLoader,
-                      self).__init__(data_list, batch_size=1,
-                                     collate_fn=lambda x: x[0], **kwargs)
-        else:
-            super(SubgraphLoader,
-                  self).__init__(range(self.__N__),
-                                 collate_fn=self.sample_nodes,
-                                 batch_size=batch_size, **kwargs)
-
-    def sample_partitions(self, batches: List[Tuple[int, Tensor]]) -> SubData:
-        ptr_ids, n_ids = zip(*batches)
-        n_id = torch.cat(n_ids, dim=0)
-        batch_size = n_id.numel()
-        ptr_id = torch.tensor(ptr_ids)
-        offset = self.__ptr__[ptr_id]
-        count = self.__ptr__[ptr_id.add_(1)].sub_(offset)
-
-        rowptr, col, value = self.__adj_t__.csr()
-        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
-                                              self.__bipartite__)
-        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
-                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
-                             is_sorted=True)
-
-        if self.__data__ is None:
-            return SubData(adj_t, batch_size, n_id, offset, count)
-
-        data = self.__data__.__class__(adj_t=adj_t)
-        for key, item in self.__data__:
-            if isinstance(item, Tensor) and item.size(0) == self.__N__:
-                data[key] = item.index_select(0, n_id)
-            elif isinstance(item, SparseTensor):
-                pass
-            else:
-                data[key] = item
-
-        return SubData(data, batch_size, n_id, offset, count)
-
-    def sample_nodes(self, n_ids: List[int]) -> SubData:
-        n_id = torch.tensor(n_ids)
-        batch_size = n_id.numel()
-
-        rowptr, col, value = self.__adj_t__.csr()
-        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
-                                              self.__bipartite__)
-        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
-                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
-                             is_sorted=True)
-
-        if self.__data__ is None:
-            return SubData(adj_t, batch_size, n_id, None, None)
-
-        data = self.__data__.__class__(adj_t=adj_t)
-        for key, item in self.__data__:
-            if isinstance(item, Tensor) and item.size(0) == self.__N__:
-                data[key] = item.index_select(0, n_id)
-            elif isinstance(item, SparseTensor):
-                pass
-            else:
-                data[key] = item
-
-        return SubData(data, batch_size, n_id, None, None)
+    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
+    object, generates subgraphs (including their 1-hop neighbors) from the
+    mini-batches defined in :obj:`ptr`."""
+    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
+                 bipartite: bool = True, log: bool = True, **kwargs):
+
+        self.data = data
+        self.ptr = ptr
+        self.bipartite = bipartite
+        self.log = log
+
+        n_id = torch.arange(data.num_nodes)
+        batches = n_id.split((ptr[1:] - ptr[:-1]).tolist())
+        batches = [(i, batches[i]) for i in range(len(batches))]
+
+        if batch_size > 1:
+            super(SubgraphLoader, self).__init__(
+                batches,
+                batch_size=batch_size,
+                collate_fn=self.compute_subgraph,
+                **kwargs,
+            )
+
+        else:  # If `batch_size=1`, we pre-process the subgraph generation:
+            if log:
+                t = time.perf_counter()
+                print('Pre-processing subgraphs...', end=' ', flush=True)
+
+            data_list = list(
+                DataLoader(batches, collate_fn=self.compute_subgraph,
+                           batch_size=batch_size, **kwargs))
+
+            if log:
+                print(f'Done! [{time.perf_counter() - t:.2f}s]')
+
+            super(SubgraphLoader, self).__init__(
+                data_list,
+                batch_size=batch_size,
+                collate_fn=lambda x: x[0],
+                **kwargs,
+            )
+
+    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
+        batch_ids, n_ids = zip(*batches)
+        n_id = torch.cat(n_ids, dim=0)
+        batch_id = torch.tensor(batch_ids)
+
+        # We collect the mini-batch size (`batch_size`), the offset of each
+        # partition in the mini-batch (`offset`), and the number of nodes in
+        # each partition (`count`):
+        batch_size = n_id.numel()
+        offset = self.ptr[batch_id]
+        count = self.ptr[batch_id.add_(1)].sub_(offset)
+
+        rowptr, col, value = self.data.adj_t.csr()
+        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
+                                              self.bipartite)
+        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
+                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
+                             is_sorted=True)
+
+        data = self.data.__class__(adj_t=adj_t)
+        for k, v in self.data:
+            if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
+                data[k] = v.index_select(0, n_id)
+
+        return SubData(data, batch_size, n_id, offset, count)

     def __repr__(self):
         return f'{self.__class__.__name__}()'
 class EvalSubgraphLoader(SubgraphLoader):
-    def __init__(
-        self,
-        data: Union[Data, SparseTensor],
-        ptr: Optional[Tensor] = None,
-        batch_size: int = 1,
-        bipartite: bool = True,
-        log: bool = True,
-        **kwargs,
-    ):
-        num_nodes = ptr[-1]
+    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
+    object, generates subgraphs (including their 1-hop neighbors) from the
+    mini-batches defined in :obj:`ptr`.
+    In contrast to :class:`SubgraphLoader`, this loader does not generate
+    subgraphs from randomly sampled mini-batches, and should therefore only
+    be used for evaluation.
+    """
+    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
+                 bipartite: bool = True, log: bool = True, **kwargs):
+
         ptr = ptr[::batch_size]
-        if int(ptr[-1]) != int(num_nodes):
-            ptr = torch.cat([ptr, num_nodes.unsqueeze(0)], dim=0)
-
-        super(EvalSubgraphLoader,
-              self).__init__(data, ptr, 1, bipartite, log, num_workers=0,
-                             shuffle=False, **kwargs)
+        if int(ptr[-1]) != data.num_nodes:
+            ptr = torch.cat([ptr, torch.tensor([data.num_nodes])], dim=0)
+
+        super(EvalSubgraphLoader, self).__init__(
+            data=data,
+            ptr=ptr,
+            batch_size=1,
+            bipartite=bipartite,
+            log=log,
+            shuffle=False,
+            num_workers=0,
+            **kwargs,
+        )
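
For reference, the `offset`/`count` bookkeeping introduced in `compute_subgraph` is plain indexing into the cumulative partition pointer `ptr`. A minimal sketch with made-up partition sizes (the numbers are illustrative, not from the repository):

import torch

# Cumulative partition boundaries: three partitions [0:3), [3:5), [5:9).
ptr = torch.tensor([0, 3, 5, 9])

# A mini-batch consisting of partitions 0 and 2:
batch_id = torch.tensor([0, 2])

offset = ptr[batch_id]                      # start index of each partition
count = ptr[batch_id.add_(1)].sub_(offset)  # number of nodes per partition

print(offset.tolist(), count.tolist())  # [0, 5] [3, 4]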
...
@@ -88,8 +88,7 @@ class ScalableGNN(torch.nn.Module):
         for hist in self.histories:
             self.pool.async_pull(hist.emb, None, None, n_id[batch_size:])

-        out = self.forward(x=x, adj_t=adj_t, batch_size=batch_size, n_id=n_id,
-                           offset=offset, count=count, **kwargs)
+        out = self.forward(x, adj_t, batch_size, n_id, offset, count, **kwargs)

         if self._async:
             for hist in self.histories:
...
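
For context, a minimal usage sketch of the cleaned-up loaders. It assumes the `torch_geometric_autoscale` package (with its compiled `relabel_one_hop` op) is installed and exports `SubgraphLoader`/`EvalSubgraphLoader`, and that nodes are already permuted into contiguous partitions; the toy graph and partition boundaries below are made up:

import torch
from torch_geometric.data import Data
from torch_sparse import SparseTensor
# Assumed exports, matching the repository's example scripts:
from torch_geometric_autoscale import SubgraphLoader, EvalSubgraphLoader

# Toy graph: 6 nodes in two partitions [0:3) and [3:6), random features.
adj_t = SparseTensor.from_dense(torch.ones(6, 6))
data = Data(x=torch.randn(6, 16), adj_t=adj_t)
ptr = torch.tensor([0, 3, 6])  # cumulative partition sizes

train_loader = SubgraphLoader(data, ptr, batch_size=1, shuffle=True)
test_loader = EvalSubgraphLoader(data, ptr, batch_size=2)

for sub in train_loader:  # each `sub` is a `SubData` named tuple
    # `sub.data.adj_t` is the relabeled subgraph including 1-hop neighbors;
    # the first `sub.batch_size` entries of `sub.n_id` are the in-batch nodes.
    print(sub.batch_size, sub.n_id.numel(), sub.offset, sub.count)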