Unverified Commit a260a6e6 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix GATNE-T (#2061)



* Update README.md

* Update

* Fix
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
parent 3af0e91c
...@@ -8,17 +8,24 @@ Requirements ...@@ -8,17 +8,24 @@ Requirements
------------ ------------
- requirements - requirements
``bash ```bash
pip install requirements pip install -r requirements.txt
`` ```
Datasets Datasets
-------- --------
* [example](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/example.zip)
* [amazon](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/amazon.zip)
* [youtube](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/youtube.zip)
* [twitter](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/twitter.zip)
To prepare the datasets:
1. ```bash
mkdir data
cd data
```
2. Download datasets from the following links:
- example: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/example.zip
- amazon: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/amazon.zip
- youtube: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/youtube.zip
- twitter: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/twitter.zip
3. Unzip the datasets
Training Training
-------- --------
......
...@@ -39,11 +39,12 @@ def get_graph(network_data, vocab): ...@@ -39,11 +39,12 @@ def get_graph(network_data, vocab):
for edge_type in network_data: for edge_type in network_data:
tmp_data = network_data[edge_type] tmp_data = network_data[edge_type]
edges = [] src = []
dst = []
for edge in tmp_data: for edge in tmp_data:
edges.append((vocab[edge[0]], vocab[edge[1]])) src.extend([vocab[edge[0]], vocab[edge[1]]])
edges.append((vocab[edge[1]], vocab[edge[0]])) dst.extend([vocab[edge[1]], vocab[edge[0]]])
data_dict[(node_type, edge_type, node_type)] = zip(*edges) data_dict[(node_type, edge_type, node_type)] = (src, dst)
graph = dgl.heterograph(data_dict, num_nodes_dict) graph = dgl.heterograph(data_dict, num_nodes_dict)
return graph return graph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment