Commit 727f3279 authored by rusty1s

clean up history

parent b8ee6962
@@ -5,10 +5,9 @@ from torch import Tensor
 
 
 class History(torch.nn.Module):
-    r"""A node embedding storage module with asynchronous I/O support between
-    devices."""
+    r"""A historical embedding storage module."""
     def __init__(self, num_embeddings: int, embedding_dim: int, device=None):
-        super(History, self).__init__()
+        super().__init__()
 
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -25,37 +24,38 @@ class History(torch.nn.Module):
         self.emb.fill_(0)
 
     def _apply(self, fn):
+        # Set the `_device` of the module without transferring `self.emb`.
         self._device = fn(torch.zeros(1)).device
         return self
 
     @torch.no_grad()
-    def pull(self, index: Optional[Tensor] = None) -> Tensor:
+    def pull(self, n_id: Optional[Tensor] = None) -> Tensor:
         out = self.emb
-        if index is not None:
-            assert index.device == self.emb.device
-            out = out.index_select(0, index)
+        if n_id is not None:
+            assert n_id.device == self.emb.device
+            out = out.index_select(0, n_id)
         return out.to(device=self._device)
 
     @torch.no_grad()
-    def push(self, x, index: Optional[Tensor] = None,
+    def push(self, x, n_id: Optional[Tensor] = None,
              offset: Optional[Tensor] = None, count: Optional[Tensor] = None):
 
-        if index is None and x.size(0) != self.num_embeddings:
+        if n_id is None and x.size(0) != self.num_embeddings:
             raise ValueError
 
-        elif index is None and x.size(0) == self.num_embeddings:
+        elif n_id is None and x.size(0) == self.num_embeddings:
             self.emb.copy_(x)
 
-        elif index is not None and (offset is None or count is None):
-            assert index.device == self.emb.device
-            self.emb[index] = x.to(self.emb.device)
+        elif offset is None or count is None:
+            assert n_id.device == self.emb.device
+            self.emb[n_id] = x.to(self.emb.device)
 
-        else:
-            x_o = 0
+        else:  # Push in chunks:
+            src_o = 0
             x = x.to(self.emb.device)
-            for o, c, in zip(offset.tolist(), count.tolist()):
-                self.emb[o:o + c] = x[x_o:x_o + c]
-                x_o += c
+            for dst_o, c, in zip(offset.tolist(), count.tolist()):
+                self.emb[dst_o:dst_o + c] = x[src_o:src_o + c]
+                src_o += c
 
     def forward(self, *args, **kwargs):
         """"""
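
For orientation, the renamed `pull`/`push` interface above reads rows of the history buffer onto the compute device and writes updated rows back, either by node index or in contiguous chunks described by `offset`/`count` pairs. A minimal usage sketch of that reading follows; the import path, shapes, and device split are illustrative assumptions, not part of this commit:

import torch

from history import History  # hypothetical import path for the module above

num_nodes, dim = 100, 16
hist = History(num_nodes, dim, device='cpu')  # `hist.emb` is assumed to live on CPU
hist = hist.to('cuda' if torch.cuda.is_available() else 'cpu')  # only sets `_device`

# Pull the historical embeddings of a mini-batch onto the compute device:
n_id = torch.arange(10)  # node indices, on the same device as `hist.emb`
h = hist.pull(n_id)      # shape [10, dim], returned on `hist._device`

# Push fresh embeddings back by index ...
hist.push(h.new_zeros(10, dim), n_id)

# ... or in contiguous chunks: x[0:3] goes to emb[0:3], x[3:7] to emb[50:54].
x = torch.rand(7, dim)
offset = torch.tensor([0, 50])  # destination start rows in `hist.emb`
count = torch.tensor([3, 4])    # consecutive rows taken from `x` per chunk
hist.push(x, n_id=n_id, offset=offset, count=count)  # `n_id` only routes to the chunked branch here
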
@@ -40,12 +40,8 @@ class SubgraphLoader(DataLoader):
         batches = [(i, batches[i]) for i in range(len(batches))]
 
         if batch_size > 1:
-            super(SubgraphLoader, self).__init__(
-                batches,
-                batch_size=batch_size,
-                collate_fn=self.compute_subgraph,
-                **kwargs,
-            )
+            super().__init__(batches, batch_size=batch_size,
+                             collate_fn=self.compute_subgraph, **kwargs)
 
         else:  # If `batch_size=1`, we pre-process the subgraph generation:
             if log:
@@ -59,12 +55,8 @@ class SubgraphLoader(DataLoader):
             if log:
                 print(f'Done! [{time.perf_counter() - t:.2f}s]')
 
-            super(SubgraphLoader, self).__init__(
-                data_list,
-                batch_size=batch_size,
-                collate_fn=lambda x: x[0],
-                **kwargs,
-            )
+            super().__init__(data_list, batch_size=batch_size,
+                             collate_fn=lambda x: x[0], **kwargs)
 
     def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
         batch_ids, n_ids = zip(*batches)
@@ -112,13 +104,5 @@ class EvalSubgraphLoader(SubgraphLoader):
         if int(ptr[-1]) != data.num_nodes:
            ptr = torch.cat([ptr, torch.tensor(data.num_nodes)], dim=0)
 
-        super(EvalSubgraphLoader, self).__init__(
-            data=data,
-            ptr=ptr,
-            batch_size=1,
-            bipartite=bipartite,
-            log=log,
-            shuffle=False,
-            num_workers=0,
-            **kwargs,
-        )
+        super().__init__(data=data, ptr=ptr, batch_size=1, bipartite=bipartite,
+                         log=log, shuffle=False, num_workers=0, **kwargs)
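
Both loaders above ultimately delegate to the stock `torch.utils.data.DataLoader` constructor: with `batch_size > 1` the raw index batches are collated through `compute_subgraph`, while with `batch_size=1` the subgraphs are pre-computed and `collate_fn=lambda x: x[0]` simply unwraps the single pre-built item. A stand-alone sketch of that second pattern, using placeholder dicts instead of this repository's `SubData` objects:

import torch
from torch.utils.data import DataLoader

# Pretend each element is an already-materialized subgraph (placeholder dicts).
data_list = [{'n_id': torch.arange(i, i + 4)} for i in range(0, 12, 4)]

# `batch_size=1` plus `collate_fn=lambda x: x[0]` hands every pre-built item
# through unchanged instead of stacking it into an extra batch dimension.
loader = DataLoader(data_list, batch_size=1, collate_fn=lambda x: x[0],
                    shuffle=False, num_workers=0)

for sub in loader:
    print(sub['n_id'])  # e.g. tensor([0, 1, 2, 3])
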