Unverified Commit 264d96cd authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[bugfix] Fix a bunch of examples to be compatible with dgl 0.5 (#1957)



* upd

* upd

* upd

* upd

* upd

* upd

* fix pinsage also

* upd

* upd

* upd
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-29-3.us-east-2.compute.internal>
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent dcf46412
...@@ -16,8 +16,8 @@ class ItemToItemBatchSampler(IterableDataset): ...@@ -16,8 +16,8 @@ class ItemToItemBatchSampler(IterableDataset):
self.g = g self.g = g
self.user_type = user_type self.user_type = user_type
self.item_type = item_type self.item_type = item_type
self.user_to_item_etype = list(g.metagraph[user_type][item_type])[0] self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
self.item_to_user_etype = list(g.metagraph[item_type][user_type])[0] self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]
self.batch_size = batch_size self.batch_size = batch_size
def __iter__(self): def __iter__(self):
...@@ -38,8 +38,8 @@ class NeighborSampler(object): ...@@ -38,8 +38,8 @@ class NeighborSampler(object):
self.g = g self.g = g
self.user_type = user_type self.user_type = user_type
self.item_type = item_type self.item_type = item_type
self.user_to_item_etype = list(g.metagraph[user_type][item_type])[0] self.user_to_item_etype = list(g.metagraph()[user_type][item_type])[0]
self.item_to_user_etype = list(g.metagraph[item_type][user_type])[0] self.item_to_user_etype = list(g.metagraph()[item_type][user_type])[0]
self.samplers = [ self.samplers = [
dgl.sampling.PinSAGESampler(g, item_type, user_type, random_walk_length, dgl.sampling.PinSAGESampler(g, item_type, user_type, random_walk_length,
random_walk_restart_prob, num_random_walks, num_neighbors) random_walk_restart_prob, num_random_walks, num_neighbors)
......
...@@ -42,7 +42,7 @@ def main(args): ...@@ -42,7 +42,7 @@ def main(args):
cuda = False cuda = False
else: else:
cuda = True cuda = True
g = g.to(args.gpu) g = g.int().to(args.gpu)
features = g.ndata['feat'] features = g.ndata['feat']
labels = g.ndata['label'] labels = g.ndata['label']
......
...@@ -36,7 +36,7 @@ def main(args): ...@@ -36,7 +36,7 @@ def main(args):
cuda = False cuda = False
else: else:
cuda = True cuda = True
g = g.to(args.gpu) g = g.int().to(args.gpu)
features = g.ndata['feat'] features = g.ndata['feat']
labels = g.ndata['label'] labels = g.ndata['label']
......
# Transformer in DGL # Transformer in DGL
**This example is out-dated, please refer to [BP-Transformer](http://github.com/yzh119/bpt) for efficient (Sparse) Transformer implementation in DGL.**
In this example we implement the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) with ACT in DGL. In this example we implement the [Transformer](https://arxiv.org/pdf/1706.03762.pdf) with ACT in DGL.
The folder contains training module and inferencing module (beam decoder) for Transformer. The folder contains training module and inferencing module (beam decoder) for Transformer.
......
...@@ -126,7 +126,7 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -126,7 +126,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
g.ndata['test_mask'] = generate_mask_tensor(test_mask) g.ndata['test_mask'] = generate_mask_tensor(test_mask)
g.ndata['label'] = F.tensor(labels) g.ndata['label'] = F.tensor(labels)
g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32']) g.ndata['feat'] = F.tensor(_preprocess_features(features), dtype=F.data_type_dict['float32'])
self._num_labels = onehot_labels.shape[1] self._num_classes = onehot_labels.shape[1]
self._labels = labels self._labels = labels
self._g = g self._g = g
...@@ -135,7 +135,7 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -135,7 +135,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
print(' NumNodes: {}'.format(self._g.number_of_nodes())) print(' NumNodes: {}'.format(self._g.number_of_nodes()))
print(' NumEdges: {}'.format(self._g.number_of_edges())) print(' NumEdges: {}'.format(self._g.number_of_edges()))
print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1])) print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
print(' NumClasses: {}'.format(self.num_labels)) print(' NumClasses: {}'.format(self.num_classes))
print(' NumTrainingSamples: {}'.format( print(' NumTrainingSamples: {}'.format(
F.nonzero_1d(self._g.ndata['train_mask']).shape[0])) F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
print(' NumValidationSamples: {}'.format( print(' NumValidationSamples: {}'.format(
...@@ -161,7 +161,7 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -161,7 +161,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
info_path = os.path.join(self.save_path, info_path = os.path.join(self.save_path,
self.save_name + '.pkl') self.save_name + '.pkl')
save_graphs(str(graph_path), self._g) save_graphs(str(graph_path), self._g)
save_info(str(info_path), {'num_labels': self.num_labels}) save_info(str(info_path), {'num_classes': self.num_classes})
def load(self): def load(self):
graph_path = os.path.join(self.save_path, graph_path = os.path.join(self.save_path,
...@@ -181,7 +181,7 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -181,7 +181,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
graph = to_networkx(graph) graph = to_networkx(graph)
self._graph = nx.DiGraph(graph) self._graph = nx.DiGraph(graph)
self._num_labels = info['num_labels'] self._num_classes = info['num_classes']
self._g.ndata['train_mask'] = generate_mask_tensor(self._g.ndata['train_mask'].numpy()) self._g.ndata['train_mask'] = generate_mask_tensor(self._g.ndata['train_mask'].numpy())
self._g.ndata['val_mask'] = generate_mask_tensor(self._g.ndata['val_mask'].numpy()) self._g.ndata['val_mask'] = generate_mask_tensor(self._g.ndata['val_mask'].numpy())
self._g.ndata['test_mask'] = generate_mask_tensor(self._g.ndata['test_mask'].numpy()) self._g.ndata['test_mask'] = generate_mask_tensor(self._g.ndata['test_mask'].numpy())
...@@ -191,7 +191,7 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -191,7 +191,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
print(' NumNodes: {}'.format(self._g.number_of_nodes())) print(' NumNodes: {}'.format(self._g.number_of_nodes()))
print(' NumEdges: {}'.format(self._g.number_of_edges())) print(' NumEdges: {}'.format(self._g.number_of_edges()))
print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1])) print(' NumFeats: {}'.format(self._g.ndata['feat'].shape[1]))
print(' NumClasses: {}'.format(self.num_labels)) print(' NumClasses: {}'.format(self.num_classes))
print(' NumTrainingSamples: {}'.format( print(' NumTrainingSamples: {}'.format(
F.nonzero_1d(self._g.ndata['train_mask']).shape[0])) F.nonzero_1d(self._g.ndata['train_mask']).shape[0]))
print(' NumValidationSamples: {}'.format( print(' NumValidationSamples: {}'.format(
...@@ -212,7 +212,12 @@ class CitationGraphDataset(DGLBuiltinDataset): ...@@ -212,7 +212,12 @@ class CitationGraphDataset(DGLBuiltinDataset):
@property @property
def num_labels(self): def num_labels(self):
return self._num_labels deprecate_property('dataset.num_labels', 'dataset.num_classes')
return self.num_classes
@property
def num_classes(self):
return self._num_classes
""" Citation graph is used in many examples """ Citation graph is used in many examples
We preserve these properties for compatability. We preserve these properties for compatability.
...@@ -339,7 +344,7 @@ class CoraGraphDataset(CitationGraphDataset): ...@@ -339,7 +344,7 @@ class CoraGraphDataset(CitationGraphDataset):
Attributes Attributes
---------- ----------
num_labels: int num_classes: int
Number of label classes Number of label classes
graph: networkx.DiGraph graph: networkx.DiGraph
Graph structure Graph structure
...@@ -362,7 +367,7 @@ class CoraGraphDataset(CitationGraphDataset): ...@@ -362,7 +367,7 @@ class CoraGraphDataset(CitationGraphDataset):
-------- --------
>>> dataset = CoraGraphDataset() >>> dataset = CoraGraphDataset()
>>> g = dataset[0] >>> g = dataset[0]
>>> num_class = g.num_labels >>> num_class = g.num_classes
>>> >>>
>>> # get node feature >>> # get node feature
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
...@@ -479,7 +484,7 @@ class CiteseerGraphDataset(CitationGraphDataset): ...@@ -479,7 +484,7 @@ class CiteseerGraphDataset(CitationGraphDataset):
Attributes Attributes
---------- ----------
num_labels: int num_classes: int
Number of label classes Number of label classes
graph: networkx.DiGraph graph: networkx.DiGraph
Graph structure Graph structure
...@@ -505,7 +510,7 @@ class CiteseerGraphDataset(CitationGraphDataset): ...@@ -505,7 +510,7 @@ class CiteseerGraphDataset(CitationGraphDataset):
-------- --------
>>> dataset = CiteseerGraphDataset() >>> dataset = CiteseerGraphDataset()
>>> g = dataset[0] >>> g = dataset[0]
>>> num_class = g.num_labels >>> num_class = g.num_classes
>>> >>>
>>> # get node feature >>> # get node feature
>>> feat = g.ndata['feat'] >>> feat = g.ndata['feat']
...@@ -622,7 +627,7 @@ class PubmedGraphDataset(CitationGraphDataset): ...@@ -622,7 +627,7 @@ class PubmedGraphDataset(CitationGraphDataset):
Attributes Attributes
---------- ----------
num_labels: int num_classes: int
Number of label classes Number of label classes
graph: networkx.DiGraph graph: networkx.DiGraph
Graph structure Graph structure
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment