"tests/python/tensorflow/test_nn.py" did not exist on "7c598aac6c25fbee53e52f6bd54c2fd04bad2151"
Unverified Commit 264d96cd authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[bugfix] Fix a bunch of examples to be compatible with dgl 0.5 (#1957)



* upd

* upd

* upd

* upd

* upd

* upd

* fix pinsage also

* upd

* upd

* upd
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-29-3.us-east-2.compute.internal>
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent dcf46412
......@@ -16,8 +16,8 @@ class ItemToItemBatchSampler(IterableDataset):
self.g = g
self.user_type = user_type
self.item_type = item_type
self.user_to_item_etype = list(g.metagraph[user_type][item_type])[0]
self.item_to_user_etype = list(g.metagraph[item_type][user_type])[0]
self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]
self.batch_size = batch_size
def __iter__(self):
......@@ -38,8 +38,8 @@ class NeighborSampler(object):
self.g = g
self.user_type = user_type
self.item_type = item_type
self.user_to_item_etype = list(g.metagraph[user_type][item_type])[0]
self.item_to_user_etype = list(g.metagraph[item_type][user_type])[0]
self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]
self.samplers = [
dgl.sampling.PinSAGESampler(g, item_type, user_type, random_walk_length,
random_walk_restart_prob, num_random_walks, num_neighbors)
......
......@@ -42,7 +42,7 @@ def main(args):
cuda = False
else:
cuda = True
g = g.to(args.gpu)
g = g.int().to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
......
......@@ -36,7 +36,7 @@ def main(args):
cuda = False
else:
cuda = True
g = g.to(args.gpu)
g = g.int().to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
......
# Transformer in DGL
**This example is out-dated, please refer to [BP-Transformer](http://github.com/yzh119/bpt) for efficient (Sparse) Transformer implementation in DGL.**
In this example we implement the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) with ACT in DGL.
The folder contains training module and inferencing module (beam decoder) for Transformer.
......
......@@ -126,7 +126,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
g.ndata['test_mask'] = generate_mask_tensor(test_mask)
g.ndata['label'] = F.tensor(labels)
g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
self._num_labels = onehot_labels.shape[1]
self._num_classes = onehot_labels.shape[1]
self._labels = labels
self._g = g
......@@ -135,7 +135,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
print(' NumNodes: {}'.format(self._g.number_of_nodes()))
print(' NumEdges: {}'.format(self._g.number_of_edges()))
print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
print(' NumClasses: {}'.format(self.num_labels))
print(' NumClasses: {}'.format(self.num_classes))
print(' NumTrainingSamples: {}'.format(
F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
print(' NumValidationSamples: {}'.format(
......@@ -161,7 +161,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
info_path = os.path.join(self.save_path,
self.save_name + '.pkl')
save_graphs(str(graph_path), self._g)
save_info(str(info_path), {'num_labels': self.num_labels})
save_info(str(info_path), {'num_classes': self.num_classes})
def load(self):
graph_path = os.path.join(self.save_path,
......@@ -181,7 +181,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
graph = to_networkx(graph)
self._graph = nx.DiGraph(graph)
self._num_labels = info['num_labels']
self._num_classes = info['num_classes']
self._g.ndata['train_mask'] = generate_mask_tensor(self._g.ndata['train_mask'].numpy())
self._g.ndata['val_mask'] = generate_mask_tensor(self._g.ndata['val_mask'].numpy())
self._g.ndata['test_mask'] = generate_mask_tensor(self._g.ndata['test_mask'].numpy())
......@@ -191,7 +191,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
print(' NumNodes: {}'.format(self._g.number_of_nodes()))
print(' NumEdges: {}'.format(self._g.number_of_edges()))
print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
print(' NumClasses: {}'.format(self.num_labels))
print(' NumClasses: {}'.format(self.num_classes))
print(' NumTrainingSamples: {}'.format(
F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
print(' NumValidationSamples: {}'.format(
......@@ -212,7 +212,12 @@ class CitationGraphDataset(DGLBuiltinDataset):
@property
def num_labels(self):
return self._num_labels
deprecate_property('dataset.num_labels', 'dataset.num_classes')
return self.num_classes
@property
def num_classes(self):
return self._num_classes
""" Citation graph is used in many examples
We preserve these properties for compatability.
......@@ -339,7 +344,7 @@ class CoraGraphDataset(CitationGraphDataset):
Attributes
----------
num_labels: int
num_classes: int
Number of label classes
graph: networkx.DiGraph
Graph structure
......@@ -362,7 +367,7 @@ class CoraGraphDataset(CitationGraphDataset):
--------
>>> dataset = CoraGraphDataset()
>>> g = dataset[0]
>>> num_class = g.num_labels
>>> num_class = g.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
......@@ -479,7 +484,7 @@ class CiteseerGraphDataset(CitationGraphDataset):
Attributes
----------
num_labels: int
num_classes: int
Number of label classes
graph: networkx.DiGraph
Graph structure
......@@ -505,7 +510,7 @@ class CiteseerGraphDataset(CitationGraphDataset):
--------
>>> dataset = CiteseerGraphDataset()
>>> g = dataset[0]
>>> num_class = g.num_labels
>>> num_class = g.num_classes
>>>
>>> # get node feature
>>> feat = g.ndata['feat']
......@@ -622,7 +627,7 @@ class PubmedGraphDataset(CitationGraphDataset):
Attributes
----------
num_labels: int
num_classes: int
Number of label classes
graph: networkx.DiGraph
Graph structure
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment