Unverified Commit a260a6e6 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix GATNE-T (#2061)



* Update README.md

* Update

* Fix
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
parent 3af0e91c
......@@ -8,17 +8,24 @@ Requirements
------------
- requirements
``bash
pip install requirements
``
```bash
pip install -r requirements.txt
```
Datasets
--------
* [example](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/example.zip)
* [amazon](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/amazon.zip)
* [youtube](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/youtube.zip)
* [twitter](https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/twitter.zip)
To prepare the datasets:
1. ```bash
mkdir data
cd data
```
2. Download datasets from the following links:
- example: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/example.zip
- amazon: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/amazon.zip
- youtube: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/youtube.zip
- twitter: https://s3.us-west-2.amazonaws.com/dgl-data/dataset/recsys/GATNE/twitter.zip
3. Unzip the datasets
Training
--------
......
......@@ -39,11 +39,12 @@ def get_graph(network_data, vocab):
for edge_type in network_data:
tmp_data = network_data[edge_type]
edges = []
src = []
dst = []
for edge in tmp_data:
edges.append((vocab[edge[0]], vocab[edge[1]]))
edges.append((vocab[edge[1]], vocab[edge[0]]))
data_dict[(node_type, edge_type, node_type)] = zip(*edges)
src.extend([vocab[edge[0]], vocab[edge[1]]])
dst.extend([vocab[edge[1]], vocab[edge[0]]])
data_dict[(node_type, edge_type, node_type)] = (src, dst)
graph = dgl.heterograph(data_dict, num_nodes_dict)
return graph
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment