Commit 727f3279 authored by rusty1s

clean up history

parent b8ee6962
@@ -5,10 +5,9 @@ from torch import Tensor
 class History(torch.nn.Module):
-    r"""A node embedding storage module with asynchronous I/O support between
-    devices."""
+    r"""A historical embedding storage module."""
     def __init__(self, num_embeddings: int, embedding_dim: int, device=None):
-        super(History, self).__init__()
+        super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@@ -25,37 +24,38 @@ class History(torch.nn.Module):
         self.emb.fill_(0)
     def _apply(self, fn):
+        # Set the `_device` of the module without transferring `self.emb`.
         self._device = fn(torch.zeros(1)).device
         return self
     @torch.no_grad()
-    def pull(self, index: Optional[Tensor] = None) -> Tensor:
+    def pull(self, n_id: Optional[Tensor] = None) -> Tensor:
         out = self.emb
-        if index is not None:
-            assert index.device == self.emb.device
-            out = out.index_select(0, index)
+        if n_id is not None:
+            assert n_id.device == self.emb.device
+            out = out.index_select(0, n_id)
         return out.to(device=self._device)
     @torch.no_grad()
-    def push(self, x, index: Optional[Tensor] = None,
+    def push(self, x, n_id: Optional[Tensor] = None,
              offset: Optional[Tensor] = None, count: Optional[Tensor] = None):
-        if index is None and x.size(0) != self.num_embeddings:
+        if n_id is None and x.size(0) != self.num_embeddings:
             raise ValueError
-        elif index is None and x.size(0) == self.num_embeddings:
+        elif n_id is None and x.size(0) == self.num_embeddings:
             self.emb.copy_(x)
-        elif index is not None and (offset is None or count is None):
-            assert index.device == self.emb.device
-            self.emb[index] = x.to(self.emb.device)
-        else:
-            x_o = 0
+        elif offset is None or count is None:
+            assert n_id.device == self.emb.device
+            self.emb[n_id] = x.to(self.emb.device)
+        else:  # Push in chunks:
+            src_o = 0
             x = x.to(self.emb.device)
-            for o, c in zip(offset.tolist(), count.tolist()):
-                self.emb[o:o + c] = x[x_o:x_o + c]
-                x_o += c
+            for dst_o, c in zip(offset.tolist(), count.tolist()):
+                self.emb[dst_o:dst_o + c] = x[src_o:src_o + c]
+                src_o += c
     def forward(self, *args, **kwargs):
         """"""
@@ -40,12 +40,8 @@ class SubgraphLoader(DataLoader):
         batches = [(i, batches[i]) for i in range(len(batches))]
         if batch_size > 1:
-            super(SubgraphLoader, self).__init__(
-                batches,
-                batch_size=batch_size,
-                collate_fn=self.compute_subgraph,
-                **kwargs,
-            )
+            super().__init__(batches, batch_size=batch_size,
+                             collate_fn=self.compute_subgraph, **kwargs)
         else:  # If `batch_size=1`, we pre-process the subgraph generation:
             if log:
@@ -59,12 +55,8 @@ class SubgraphLoader(DataLoader):
             if log:
                 print(f'Done! [{time.perf_counter() - t:.2f}s]')
-            super(SubgraphLoader, self).__init__(
-                data_list,
-                batch_size=batch_size,
-                collate_fn=lambda x: x[0],
-                **kwargs,
-            )
+            super().__init__(data_list, batch_size=batch_size,
+                             collate_fn=lambda x: x[0], **kwargs)
     def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
         batch_ids, n_ids = zip(*batches)
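Both SubgraphLoader hunks keep the same shape: a list of pre-indexed batches is handed straight to `torch.utils.data.DataLoader`, and the real merging work happens in `collate_fn` (either `compute_subgraph`, or `lambda x: x[0]` for pre-computed subgraphs). A self-contained sketch of that pattern, with an illustrative stand-in for `compute_subgraph`:

    import torch
    from torch.utils.data import DataLoader

    # Stand-in collate function: receives the raw list of (batch_id, n_id)
    # samples drawn for one mini-batch and merges them into a single batch.
    def merge(batches):
        batch_ids, n_ids = zip(*batches)
        return list(batch_ids), torch.cat(n_ids, dim=0)

    batches = [(i, torch.arange(i, i + 2)) for i in range(4)]
    loader = DataLoader(batches, batch_size=2, collate_fn=merge)

    for batch_ids, n_id in loader:
        print(batch_ids, n_id)  # e.g. [0, 1] tensor([0, 1, 1, 2])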
@@ -112,13 +104,5 @@ class EvalSubgraphLoader(SubgraphLoader):
         if int(ptr[-1]) != data.num_nodes:
             ptr = torch.cat([ptr, torch.tensor([data.num_nodes])], dim=0)
-        super(EvalSubgraphLoader, self).__init__(
-            data=data,
-            ptr=ptr,
-            batch_size=1,
-            bipartite=bipartite,
-            log=log,
-            shuffle=False,
-            num_workers=0,
-            **kwargs,
-        )
+        super().__init__(data=data, ptr=ptr, batch_size=1, bipartite=bipartite,
+                         log=log, shuffle=False, num_workers=0, **kwargs)