"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "930c8fdcb7bac2fc3e0f0d10caea0180a462b762"
Unverified Commit 9eaace92 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

rename create_format_ to create_formats_ (#2126)

parent 745078e7
......@@ -240,7 +240,7 @@ source-destination array pairs. An example is given as follows:
dataloader = dgl.dataloading.EdgeDataLoader(
g, train_eid_dict, sampler,
negative_sampler=negative_sampler=NegativeSampler(g, 5),
negative_sampler=NegativeSampler(g, 5),
batch_size=1024,
shuffle=True,
drop_last=False,
......
......@@ -403,8 +403,8 @@ if __name__ == '__main__':
run(0, n_gpus, args, devices, dataset)
# multi gpu
else:
dataset.train_enc_graph.create_format_()
dataset.train_dec_graph.create_format_()
dataset.train_enc_graph.create_formats_()
dataset.train_dec_graph.create_formats_()
procs = []
for proc_id in range(n_gpus):
p = mp.Process(target=thread_wrapped_func(run), args=(proc_id, n_gpus, args, devices, dataset))
......
......@@ -317,7 +317,7 @@ if __name__ == '__main__':
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
g.ndata['features'] = features
g.create_format_()
g.create_formats_()
# Pack data
data = train_mask, val_mask, in_feats, labels, n_classes, g
......
......@@ -384,7 +384,7 @@ if __name__ == '__main__':
g.ndata['features'] = features.share_memory_()
create_history_storage(g, args, n_classes)
g.create_format_()
g.create_formats_()
# Pack data
data = train_mask, val_mask, in_feats, labels, n_classes, g
......
......@@ -229,9 +229,9 @@ if __name__ == '__main__':
else:
train_g = val_g = test_g = g
train_g.create_format_()
val_g.create_format_()
test_g.create_format_()
train_g.create_formats_()
val_g.create_formats_()
test_g.create_formats_()
# Pack data
data = in_feats, n_classes, train_g, val_g, test_g
......
......@@ -258,9 +258,9 @@ if __name__ == '__main__':
else:
train_g = val_g = test_g = g
train_g.create_format_()
val_g.create_format_()
test_g.create_format_()
train_g.create_formats_()
val_g.create_formats_()
test_g.create_formats_()
# Pack data
data = in_feats, n_classes, train_g, val_g, test_g
......
......@@ -298,7 +298,7 @@ def main(args, devices):
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
g.ndata['features'] = features
g.create_format_()
g.create_formats_()
# Pack data
data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g
......
......@@ -67,5 +67,5 @@ def subgraph_collate_fn(g, batch):
g1.ndata['feat'] = g.ndata['feat'][nid]
g1.ndata['labels'] = g.ndata['labels'][nid]
g1.ndata['train_mask'] = g.ndata['train_mask'][nid]
g1.create_format_()
g1.create_formats_()
return g1
......@@ -243,7 +243,7 @@ if __name__ == '__main__':
in_feats = graph.ndata['feat'].shape[1]
n_classes = (labels.max() + 1).item()
graph.create_format_()
graph.create_formats_()
# Pack data
data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, graph, args.head
......
......@@ -234,7 +234,7 @@ if __name__ == '__main__':
in_feats = graph.ndata['feat'].shape[1]
n_classes = (labels.max() + 1).item()
graph.create_format_()
graph.create_formats_()
# Pack data
data = train_idx, val_idx, test_idx, in_feats, labels, n_classes, graph
......
......@@ -245,7 +245,7 @@ class BlockSampler(object):
seed_nodes = {ntype: block.srcnodes[ntype].data[NID] for ntype in block.srctypes}
# Pre-generate CSR format so that it can be used in training directly
block.create_format_()
block.create_formats_()
blocks.insert(0, block)
return blocks
......
......@@ -5241,7 +5241,7 @@ class DGLHeteroGraph(object):
ret._graph = self._graph.formats(formats)
return ret
def create_format_(self):
def create_formats_(self):
r"""Create all sparse matrices allowed for the graph.
By default, we create sparse matrices for a graph only when necessary.
......@@ -5261,7 +5261,7 @@ class DGLHeteroGraph(object):
>>> g = dgl.graph(([0, 0, 1], [2, 3, 2]))
>>> g.format()
{'created': ['coo'], 'not created': ['csr', 'csc']}
>>> g.create_format_()
>>> g.create_formats_()
>>> g.format()
{'created': ['coo', 'csr', 'csc'], 'not created': []}
......@@ -5275,14 +5275,14 @@ class DGLHeteroGraph(object):
... })
>>> g.format()
{'created': ['coo'], 'not created': ['csr', 'csc']}
>>> g.create_format_()
>>> g.create_formats_()
>>> g.format()
{'created': ['coo', 'csr', 'csc'], 'not created': []}
"""
if self.num_edges() == 0:
return 0
return self._graph.create_format_()
return self._graph.create_formats_()
def astype(self, idtype):
"""Cast this graph to use another ID type.
......
......@@ -909,7 +909,7 @@ class HeteroGraphIndex(ObjectBase):
formats = [formats]
return _CAPI_DGLHeteroGetFormatGraph(self, formats)
def create_format_(self):
def create_formats_(self):
"""Create all sparse matrices allowed for the graph."""
return _CAPI_DGLHeteroCreateFormat(self)
......
......@@ -1714,7 +1714,7 @@ def test_format(idtype):
assert g.formats()['created'] == ['coo']
g1 = g.formats(['coo', 'csr', 'csc'])
assert len(g1.formats()['created']) + len(g1.formats()['not created']) == 3
g1.create_format_()
g1.create_formats_()
assert len(g1.formats()['created']) == 3
assert g.formats()['created'] == ['coo']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment