Commit 5278220f authored by Haibin Lin's avatar Haibin Lin Committed by Zihao Ye
Browse files

update s3 links (#215)

parent ac932c66
Environment Variables
=====================
Backend Options
---------------
* DGLBACKEND
* Values: String (default='pytorch')
* The backend deep learning framework for DGL.
* Choices:
* 'pytorch': use PyTorch as the backend implementation.
* 'mxnet': use Apache MXNet as the backend implementation.
Data Repository
---------------
* DGL_REPO
* Values: String (default='https://s3.us-east-2.amazonaws.com/dgl.ai/')
* The repository url to be used for DGL datasets and pre-trained models.
* Suggested values:
* 'https://s3.us-east-2.amazonaws.com/dgl.ai/': DGL repo for U.S.
* 'https://s3-ap-southeast-1.amazonaws.com/dgl.ai.asia/': DGL repo for Asia
......@@ -13,11 +13,11 @@ import rdflib as rdf
import pandas as pd
from collections import Counter
from dgl.data.utils import download, extract_archive, get_download_dir
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
np.random.seed(123)
_downlaod_prefix = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/'
_downlaod_prefix = _get_dgl_url('dataset/')
class RGCNEntityDataset(object):
"""RGCN Entity Classification dataset
......
"""Cora, citeseer, pubmed dataset.
(lingfan): following dataset loading and preprocessing code from tkipf/gcn
......@@ -12,12 +11,12 @@ import networkx as nx
import scipy.sparse as sp
import os, sys
from .utils import download, extract_archive, get_download_dir
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
_urls = {
'cora' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/cora.zip',
'citeseer' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/citeseer.zip',
'pubmed' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/pubmed.zip',
'cora' : 'dataset/cora.zip',
'citeseer' : 'dataset/citeseer.zip',
'pubmed' : 'dataset/pubmed.zip',
}
class CitationGraphDataset(object):
......@@ -25,7 +24,7 @@ class CitationGraphDataset(object):
self.name = name
self.dir = get_download_dir()
self.zip_file_path='{}/{}.zip'.format(self.dir, name)
download(_urls[name], path=self.zip_file_path)
download(_get_dgl_url(_urls[name]), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/{}'.format(self.dir, name))
self._load()
......
......@@ -14,10 +14,10 @@ import numpy as np
import os
import dgl
import dgl.backend as F
from dgl.data.utils import download, extract_archive, get_download_dir
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
_urls = {
'sst' : 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/sst.zip',
'sst' : 'dataset/sst.zip',
}
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
......@@ -54,7 +54,7 @@ class SST(object):
self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
self.pretrained_emb = None
self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
download(_urls['sst'], path=self.zip_file_path)
download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
self.trees = []
self.num_classes = 5
......
......@@ -13,6 +13,15 @@ except ImportError:
pass
requests = requests_failed_to_import
def _get_dgl_url(file_url):
"""Get DGL online url for download."""
dgl_repo_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/'
repo_url = os.environ.get('DGL_REPO', dgl_repo_url)
if repo_url[-1] != '/':
repo_url = repo_url + '/'
return repo_url + file_url
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download a given URL.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment