Unverified Commit 9eaace92 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

rename create_format_ to create_formats_ (#2126)

parent 745078e7
...@@ -240,7 +240,7 @@ source-destination array pairs. An example is given as follows: ...@@ -240,7 +240,7 @@ source-destination array pairs. An example is given as follows:
dataloader = dgl.dataloading.EdgeDataLoader( dataloader = dgl.dataloading.EdgeDataLoader(
g, train_eid_dict, sampler, g, train_eid_dict, sampler,
negative_sampler=negative_sampler=NegativeSampler(g, 5), negative_sampler=NegativeSampler(g, 5),
batch_size=1024, batch_size=1024,
shuffle=True, shuffle=True,
drop_last=False, drop_last=False,
......
...@@ -403,8 +403,8 @@ if __name__ == '__main__': ...@@ -403,8 +403,8 @@ if __name__ == '__main__':
run(0, n_gpus, args, devices, dataset) run(0, n_gpus, args, devices, dataset)
# multi gpu # multi gpu
else: else:
dataset.train_enc_graph.create_format_() dataset.train_enc_graph.create_formats_()
dataset.train_dec_graph.create_format_() dataset.train_dec_graph.create_formats_()
procs = [] procs = []
for proc_id in range(n_gpus): for proc_id in range(n_gpus):
p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset)) p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset))
......
...@@ -317,7 +317,7 @@ if __name__ == '__main__': ...@@ -317,7 +317,7 @@ if __name__ == '__main__':
train_mask = g.ndata['train_mask'] train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask'] val_mask = g.ndata['val_mask']
g.ndata['features'] = features g.ndata['features'] = features
g.create_format_() g.create_formats_()
# Pack data # Pack data
data = train_mask, val_mask, in_feats, labels, n_classes, g data = train_mask, val_mask, in_feats, labels, n_classes, g
......
...@@ -384,7 +384,7 @@ if __name__ == '__main__': ...@@ -384,7 +384,7 @@ if __name__ == '__main__':
g.ndata['features'] = features.share_memory_() g.ndata['features'] = features.share_memory_()
create_history_storage(g, args, n_classes) create_history_storage(g, args, n_classes)
g.create_format_() g.create_formats_()
# Pack data # Pack data
data = train_mask, val_mask, in_feats, labels, n_classes, g data = train_mask, val_mask, in_feats, labels, n_classes, g
......
...@@ -229,9 +229,9 @@ if __name__ == '__main__': ...@@ -229,9 +229,9 @@ if __name__ == '__main__':
else: else:
train_g = val_g = test_g = g train_g = val_g = test_g = g
train_g.create_format_() train_g.create_formats_()
val_g.create_format_() val_g.create_formats_()
test_g.create_format_() test_g.create_formats_()
# Pack data # Pack data
data = in_feats, n_classes, train_g, val_g, test_g data = in_feats, n_classes, train_g, val_g, test_g
......
...@@ -258,9 +258,9 @@ if __name__ == '__main__': ...@@ -258,9 +258,9 @@ if __name__ == '__main__':
else: else:
train_g = val_g = test_g = g train_g = val_g = test_g = g
train_g.create_format_() train_g.create_formats_()
val_g.create_format_() val_g.create_formats_()
test_g.create_format_() test_g.create_formats_()
# Pack data # Pack data
data = in_feats, n_classes, train_g, val_g, test_g data = in_feats, n_classes, train_g, val_g, test_g
......
...@@ -298,7 +298,7 @@ def main(args, devices): ...@@ -298,7 +298,7 @@ def main(args, devices):
val_mask = g.ndata['val_mask'] val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask'] test_mask = g.ndata['test_mask']
g.ndata['features'] = features g.ndata['features'] = features
g.create_format_() g.create_formats_()
# Pack data # Pack data
data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g
......
...@@ -67,5 +67,5 @@ def subgraph_collate_fn(g, batch): ...@@ -67,5 +67,5 @@ def subgraph_collate_fn(g, batch):
g1.ndata['feat'] = g.ndata['feat'][nid] g1.ndata['feat'] = g.ndata['feat'][nid]
g1.ndata['labels'] = g.ndata['labels'][nid] g1.ndata['labels'] = g.ndata['labels'][nid]
g1.ndata['train_mask'] = g.ndata['train_mask'][nid] g1.ndata['train_mask'] = g.ndata['train_mask'][nid]
g1.create_format_() g1.create_formats_()
return g1 return g1
...@@ -243,7 +243,7 @@ if __name__ == '__main__': ...@@ -243,7 +243,7 @@ if __name__ == '__main__':
in_feats = graph.ndata['feat'].shape[1] in_feats = graph.ndata['feat'].shape[1]
n_classes = (labels.max() + 1).item() n_classes = (labels.max() + 1).item()
graph.create_format_() graph.create_formats_()
# Pack data # Pack data
data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, graph, args.head data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, graph, args.head
......
...@@ -234,7 +234,7 @@ if __name__ == '__main__': ...@@ -234,7 +234,7 @@ if __name__ == '__main__':
in_feats = graph.ndata['feat'].shape[1] in_feats = graph.ndata['feat'].shape[1]
n_classes = (labels.max() + 1).item() n_classes = (labels.max() + 1).item()
graph.create_format_() graph.create_formats_()
# Pack data # Pack data
data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, graph data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, graph
......
...@@ -245,7 +245,7 @@ class BlockSampler(object): ...@@ -245,7 +245,7 @@ class BlockSampler(object):
seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes} seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}
# Pre-generate CSR format so that it can be used in training directly # Pre-generate CSR format so that it can be used in training directly
block.create_format_() block.create_formats_()
blocks.insert(0, block) blocks.insert(0, block)
return blocks return blocks
......
...@@ -5241,7 +5241,7 @@ class DGLHeteroGraph(object): ...@@ -5241,7 +5241,7 @@ class DGLHeteroGraph(object):
ret._graph = self._graph.formats(formats) ret._graph = self._graph.formats(formats)
return ret return ret
def create_format_(self): def create_formats_(self):
r"""Create all sparse matrices allowed for the graph. r"""Create all sparse matrices allowed for the graph.
By default, we create sparse matrices for a graph only when necessary. By default, we create sparse matrices for a graph only when necessary.
...@@ -5261,7 +5261,7 @@ class DGLHeteroGraph(object): ...@@ -5261,7 +5261,7 @@ class DGLHeteroGraph(object):
>>> g = dgl.graph(([0, 0, 1], [2, 3, 2])) >>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))
>>> g.format() >>> g.format()
{'created': ['coo'], 'not created': ['csr', 'csc']} {'created': ['coo'], 'not created': ['csr', 'csc']}
>>> g.create_format_() >>> g.create_formats_()
>>> g.format() >>> g.format()
{'created': ['coo', 'csr', 'csc'], 'not created': []} {'created': ['coo', 'csr', 'csc'], 'not created': []}
...@@ -5275,14 +5275,14 @@ class DGLHeteroGraph(object): ...@@ -5275,14 +5275,14 @@ class DGLHeteroGraph(object):
... }) ... })
>>> g.format() >>> g.format()
{'created': ['coo'], 'not created': ['csr', 'csc']} {'created': ['coo'], 'not created': ['csr', 'csc']}
>>> g.create_format_() >>> g.create_formats_()
>>> g.format() >>> g.format()
{'created': ['coo', 'csr', 'csc'], 'not created': []} {'created': ['coo', 'csr', 'csc'], 'not created': []}
""" """
if self.num_edges() == 0: if self.num_edges() == 0:
return 0 return 0
return self._graph.create_format_() return self._graph.create_formats_()
def astype(self, idtype): def astype(self, idtype):
"""Cast this graph to use another ID type. """Cast this graph to use another ID type.
......
...@@ -909,7 +909,7 @@ class HeteroGraphIndex(ObjectBase): ...@@ -909,7 +909,7 @@ class HeteroGraphIndex(ObjectBase):
formats = [formats] formats = [formats]
return _CAPI_DGLHeteroGetFormatGraph(self, formats) return _CAPI_DGLHeteroGetFormatGraph(self, formats)
def create_format_(self): def create_formats_(self):
"""Create all sparse matrices allowed for the graph.""" """Create all sparse matrices allowed for the graph."""
return _CAPI_DGLHeteroCreateFormat(self) return _CAPI_DGLHeteroCreateFormat(self)
......
...@@ -1714,7 +1714,7 @@ def test_format(idtype): ...@@ -1714,7 +1714,7 @@ def test_format(idtype):
assert g.formats()['created'] == ['coo'] assert g.formats()['created'] == ['coo']
g1 = g.formats(['coo', 'csr', 'csc']) g1 = g.formats(['coo', 'csr', 'csc'])
assert len(g1.formats()['created']) + len(g1.formats()['not created']) == 3 assert len(g1.formats()['created']) + len(g1.formats()['not created']) == 3
g1.create_format_() g1.create_formats_()
assert len(g1.formats()['created']) == 3 assert len(g1.formats()['created']) == 3
assert g.formats()['created'] == ['coo'] assert g.formats()['created'] == ['coo']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment