Commit b8ee6962 authored by rusty1s's avatar rusty1s
Browse files

clean up loader

parent dc5f7414
from typing import Optional, Union, Tuple, NamedTuple, List
from typing import NamedTuple, List, Tuple
import time
......@@ -12,11 +12,11 @@ relabel_fn = torch.ops.torch_geometric_autoscale.relabel_one_hop
class SubData(NamedTuple):
    # One mini-batch produced by `SubgraphLoader`:
    # `data`       -- the relabeled subgraph (holds `adj_t` and node features),
    # `batch_size` -- number of "in-batch" (target) nodes at the front of `n_id`,
    # `n_id`       -- original node ids of all nodes in the subgraph,
    # `offset`     -- start offset of each sampled partition in the full graph,
    # `count`      -- number of nodes in each sampled partition.
    data: Data
    batch_size: int
    n_id: Tensor
    offset: Tensor
    count: Tensor

    def to(self, *args, **kwargs):
        """Return a copy with only :obj:`data` moved via ``data.to(...)``;
        the bookkeeping tensors are passed through unchanged.
        NOTE(review): the tail of this method is cut off in the diff view --
        the pass-through of `n_id`/`offset`/`count` is reconstructed; confirm
        against the repository."""
        return SubData(self.data.to(*args, **kwargs), self.batch_size,
                       self.n_id, self.offset, self.count)
class SubgraphLoader(DataLoader):
    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
    object, generates subgraphs (including their 1-hop neighbors) from
    mini-batches of partitions described by :obj:`ptr`.

    Args:
        data (Data): The (pre-partitioned) graph data object holding
            :obj:`adj_t` and node-level tensors.
        ptr (Tensor): Partition pointer vector; partition :obj:`i` covers the
            node range :obj:`[ptr[i], ptr[i + 1])`.
        batch_size (int, optional): Number of partitions per mini-batch.
        bipartite (bool, optional): Whether the returned adjacency is
            bipartite (target nodes x all sampled nodes). Passed through to
            the relabeling kernel.
        log (bool, optional): Whether to log pre-processing progress.
    """
    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                 bipartite: bool = True, log: bool = True, **kwargs):

        self.data = data
        self.ptr = ptr
        self.bipartite = bipartite
        self.log = log

        # One entry per partition: `(partition_id, node_ids_of_partition)`.
        n_id = torch.arange(data.num_nodes)
        batches = n_id.split((ptr[1:] - ptr[:-1]).tolist())
        batches = [(i, batch) for i, batch in enumerate(batches)]

        if batch_size > 1:
            # Subgraphs are generated on-the-fly by the collate function.
            super(SubgraphLoader, self).__init__(
                batches,
                batch_size=batch_size,
                collate_fn=self.compute_subgraph,
                **kwargs,
            )
        else:  # If `batch_size=1`, we pre-process the subgraph generation:
            if log:
                t = time.perf_counter()
                print('Pre-processing subgraphs...', end=' ', flush=True)

            data_list = list(
                DataLoader(batches, collate_fn=self.compute_subgraph,
                           batch_size=batch_size, **kwargs))

            if log:
                print(f'Done! [{time.perf_counter() - t:.2f}s]')

            # Iterate over the cached subgraphs; the identity collate simply
            # unwraps the single-element batch.
            super(SubgraphLoader, self).__init__(
                data_list,
                batch_size=batch_size,
                collate_fn=lambda x: x[0],
                **kwargs,
            )

    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
        """Collate a list of `(partition_id, node_ids)` pairs into a
        :class:`SubData` holding the induced 1-hop subgraph."""
        batch_ids, n_ids = zip(*batches)
        n_id = torch.cat(n_ids, dim=0)
        batch_id = torch.tensor(batch_ids)

        # We collect the in-mini-batch size (`batch_size`), the offset of each
        # partition in the mini-batch (`offset`), and the number of nodes in
        # each partition (`count`):
        batch_size = n_id.numel()
        offset = self.ptr[batch_id]
        count = self.ptr[batch_id.add_(1)].sub_(offset)

        # Relabel the global CSR adjacency to local (subgraph) indices and
        # extend `n_id` by the 1-hop neighborhood of the batch nodes.
        rowptr, col, value = self.data.adj_t.csr()
        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
                                              self.bipartite)

        adj_t = SparseTensor(rowptr=rowptr, col=col, value=value,
                             sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
                             is_sorted=True)

        # Copy over every node-level tensor, gathered onto the sampled nodes.
        data = self.data.__class__(adj_t=adj_t)
        for k, v in self.data:
            if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
                data[k] = v.index_select(0, n_id)

        return SubData(data, batch_size, n_id, offset, count)

    def __repr__(self):
        return f'{self.__class__.__name__}()'
class EvalSubgraphLoader(SubgraphLoader):
    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data`
    object, generates subgraphs (including their 1-hop neighbors) from
    mini-batches in :obj:`ptr`.

    In contrast to :class:`SubgraphLoader`, this loader does not generate
    subgraphs from randomly sampled mini-batches, and should therefore only
    be used for evaluation.
    """
    def __init__(self, data: Data, ptr: Tensor, batch_size: int = 1,
                 bipartite: bool = True, log: bool = True, **kwargs):
        # Merge `batch_size` consecutive partitions into one by striding over
        # `ptr`, and re-append the final boundary if the stride dropped it.
        ptr = ptr[::batch_size]
        if int(ptr[-1]) != data.num_nodes:
            # BUG FIX: `torch.tensor(data.num_nodes)` is 0-dimensional and
            # `torch.cat` rejects 0-dim tensors -- wrap the value in a list so
            # a 1-element tensor is concatenated instead.
            ptr = torch.cat([ptr, torch.tensor([data.num_nodes])], dim=0)

        # Deterministic, single-process iteration (evaluation only). Note that
        # `shuffle`/`num_workers` must not be passed again via `**kwargs`.
        super(EvalSubgraphLoader, self).__init__(
            data=data,
            ptr=ptr,
            batch_size=1,
            bipartite=bipartite,
            log=log,
            shuffle=False,
            num_workers=0,
            **kwargs,
        )
......@@ -88,8 +88,7 @@ class ScalableGNN(torch.nn.Module):
for hist in self.histories:
self.pool.async_pull(hist.emb, None, None, n_id[batch_size:])
out = self.forward(x=x, adj_t=adj_t, batch_size=batch_size, n_id=n_id,
offset=offset, count=count, **kwargs)
out = self.forward(x, adj_t, batch_size, n_id, offset, count, **kwargs)
if self._async:
for hist in self.histories:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment