Commit 5278220f authored by Haibin Lin's avatar Haibin Lin Committed by Zihao Ye
Browse files

update s3 links (#215)

parent ac932c66
Environment Variables
=====================
Backend Options
---------------
* DGLBACKEND
* Values: String (default='pytorch')
* The backend deep lerarning framework for DGL.
* Choices:
* 'pytorch': use PyTorch as the backend implentation.
* 'mxnet': use Apache MXNet as the backend implementation.
Data Repository
---------------
* DGL_REPO
* Values: String (default='https://s3.us-east-2.amazonaws.com/dgl.ai/')
* The repository url to be used for DGL datasets and pre-trained models.
* Suggested values:
* 'https://s3.us-east-2.amazonaws.com/dgl.ai/': DGL repo for U.S.
* 'https://s3-ap-southeast-1.amazonaws.com/dgl.ai.asia/': DGL repo for Asia
...@@ -13,11 +13,11 @@ import rdflib as rdf ...@@ -13,11 +13,11 @@ import rdflib as rdf
import pandas as pd import pandas as pd
from collections import Counter from collections import Counter
from dgl.data.utils import download, extract_archive, get_download_dir from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
np.random.seed(123) np.random.seed(123)
_downlaod_prefix = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/' _downlaod_prefix = _get_dgl_url('dataset/')
class RGCNEntityDataset(object): class RGCNEntityDataset(object):
"""RGCN Entity Classification dataset """RGCN Entity Classification dataset
......
"""Cora, citeseer, pubmed dataset. """Cora, citeseer, pubmed dataset.
(lingfan): following dataset loading and preprocessing code from tkipf/gcn (lingfan): following dataset loading and preprocessing code from tkipf/gcn
...@@ -12,12 +11,12 @@ import networkx as nx ...@@ -12,12 +11,12 @@ import networkx as nx
import scipy.sparse as sp import scipy.sparse as sp
import os, sys import os, sys
from .utils import download, extract_archive, get_download_dir from .utils import download, extract_archive, get_download_dir, _get_dgl_url
_urls = { _urls = {
'cora' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/cora.zip', 'cora' : 'dataset/cora.zip',
'citeseer' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/citeseer.zip', 'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/pubmed.zip', 'pubmed' : 'dataset/pubmed.zip',
} }
class CitationGraphDataset(object): class CitationGraphDataset(object):
...@@ -25,7 +24,7 @@ class CitationGraphDataset(object): ...@@ -25,7 +24,7 @@ class CitationGraphDataset(object):
self.name = name self.name = name
self.dir = get_download_dir() self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name) self.zip_file_path='{}/{}.zip'.format(self.dir, name)
download(_urls[name], path=self.zip_file_path) download(_get_dgl_url(_urls[name]), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name)) extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name))
self._load() self._load()
......
...@@ -14,10 +14,10 @@ import numpy as np ...@@ -14,10 +14,10 @@ import numpy as np
import os import os
import dgl import dgl
import dgl.backend as F import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
_urls = { _urls = {
'sst' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/sst.zip', 'sst' : 'dataset/sst.zip',
} }
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label']) SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
...@@ -54,7 +54,7 @@ class SST(object): ...@@ -54,7 +54,7 @@ class SST(object):
self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else '' self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
self.pretrained_emb = None self.pretrained_emb = None
self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
download(_urls['sst'], path=self.zip_file_path) download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/sst'.format(self.dir)) extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
self.trees = [] self.trees = []
self.num_classes = 5 self.num_classes = 5
......
...@@ -13,6 +13,15 @@ except ImportError: ...@@ -13,6 +13,15 @@ except ImportError:
pass pass
requests = requests_failed_to_import requests = requests_failed_to_import
def _get_dgl_url(file_url):
"""Get DGL online url for download."""
dgl_repo_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/'
repo_url = os.environ.get('DGL_REPO', dgl_repo_url)
if repo_url[-1] != '/':
repo_url = repo_url + '/'
return repo_url + file_url
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download a given URL. """Download a given URL.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment