"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4fc7084875ac672980c4e0207fa325c8c418a88e"
Commit 3a0f86a6 authored by Mufei Li, committed by Minjie Wang
Browse files

[Doc] data (#198)

* Fix dataset

* Track data APIs
parent 485f6d3a
.. _apidata:
Dataset
=======
.. currentmodule:: dgl.data
Utils
-----
.. autosummary::
:toctree: ../../generated/
utils.get_download_dir
utils.download
utils.check_sha1
utils.extract_archive
Dataset Classes
---------------
Stanford sentiment treebank dataset
```````````````````````````````````
For more information about the dataset, see `Sentiment Analysis <https://nlp.stanford.edu/sentiment/index.html>`__.
.. autoclass:: SST
:members: __getitem__, __len__
\ No newline at end of file
...@@ -10,3 +10,4 @@ API Reference ...@@ -10,3 +10,4 @@ API Reference
traversal traversal
propagate propagate
udf udf
data
...@@ -26,14 +26,14 @@ class SST(object): ...@@ -26,14 +26,14 @@ class SST(object):
"""Stanford Sentiment Treebank dataset. """Stanford Sentiment Treebank dataset.
Each sample is the constituency tree of a sentence. The leaf nodes Each sample is the constituency tree of a sentence. The leaf nodes
represent words. The word is a int value stored in the "x" feature field. represent words. The word is a int value stored in the ``x`` feature field.
The non-leaf node has a special value PAD_WORD. The non-leaf node has a special value ``PAD_WORD`` in the ``x`` field.
Each node also has a sentiment annotation: 5 classes (very negative, Each node also has a sentiment annotation: 5 classes (very negative,
negative, neutral, positive and very positive). The sentiment label is a negative, neutral, positive and very positive). The sentiment label is a
int value stored in the "y" feature field. int value stored in the ``y`` feature field.
.. note:: .. note::
This dataset class is compatible with pytorch's Dataset class. This dataset class is compatible with pytorch's :class:`Dataset` class.
.. note:: .. note::
All the samples will be loaded and preprocessed in the memory first. All the samples will be loaded and preprocessed in the memory first.
...@@ -41,7 +41,7 @@ class SST(object): ...@@ -41,7 +41,7 @@ class SST(object):
Parameters Parameters
---------- ----------
mode : str, optional mode : str, optional
Can be 'train', 'val', 'test'. Which data file to use. Can be ``'train'``, ``'val'``, ``'test'`` and specifies which data file to use.
vocab_file : str, optional vocab_file : str, optional
Optional vocabulary file. Optional vocabulary file.
""" """
...@@ -120,9 +120,28 @@ class SST(object): ...@@ -120,9 +120,28 @@ class SST(object):
return ret return ret
def __getitem__(self, idx): def __getitem__(self, idx):
"""Get the tree with index idx.
Parameters
----------
idx : int
Tree index.
Returns
-------
dgl.DGLGraph
Tree.
"""
return self.trees[idx] return self.trees[idx]
def __len__(self): def __len__(self):
"""Get the number of trees in the dataset.
Returns
-------
int
Number of trees.
"""
return len(self.trees) return len(self.trees)
@property @property
......
...@@ -14,24 +14,24 @@ except ImportError: ...@@ -14,24 +14,24 @@ except ImportError:
requests = requests_failed_to_import requests = requests_failed_to_import
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download an given URL """Download a given URL.
Codes borrowed from mxnet/gluon/utils.py Codes borrowed from mxnet/gluon/utils.py
Parameters Parameters
---------- ----------
url : str url : str
URL to download URL to download.
path : str, optional path : str, optional
Destination path to store downloaded file. By default stores to the Destination path to store downloaded file. By default stores to the
current directory with same name as in url. current directory with the same name as in url.
overwrite : bool, optional overwrite : bool, optional
Whether to overwrite destination file if already exists. Whether to overwrite the destination file if it already exists.
sha1_hash : str, optional sha1_hash : str, optional
Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
but doesn't match. but doesn't match.
retries : integer, default 5 retries : integer, default 5
The number of times to attempt the download in case of failure or non 200 return codes The number of times to attempt downloading in case of failure or non 200 return codes.
verify_ssl : bool, default True verify_ssl : bool, default True
Verify SSL certificates. Verify SSL certificates.
...@@ -101,6 +101,7 @@ def check_sha1(filename, sha1_hash): ...@@ -101,6 +101,7 @@ def check_sha1(filename, sha1_hash):
Path to the file. Path to the file.
sha1_hash : str sha1_hash : str
Expected sha1 hash in hexadecimal digits. Expected sha1 hash in hexadecimal digits.
Returns Returns
------- -------
bool bool
...@@ -117,14 +118,14 @@ def check_sha1(filename, sha1_hash): ...@@ -117,14 +118,14 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash return sha1.hexdigest() == sha1_hash
def extract_archive(file, target_dir): def extract_archive(file, target_dir):
"""Extract archive file """Extract archive file.
Parameters Parameters
---------- ----------
file : str file : str
Absolute path of the archive file. Absolute path of the archive file.
target_dir : str target_dir : str
Target directory of the archive to be uncompressed Target directory of the archive to be uncompressed.
""" """
if os.path.exists(target_dir): if os.path.exists(target_dir):
return return
...@@ -139,7 +140,13 @@ def extract_archive(file, target_dir): ...@@ -139,7 +140,13 @@ def extract_archive(file, target_dir):
archive.close() archive.close()
def get_download_dir(): def get_download_dir():
"""Get the absolute path to the download directory.""" """Get the absolute path to the download directory.
Returns
-------
dirname : str
Path to the download directory
"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
dirname = os.path.join(curr_path, '../../../_download') dirname = os.path.join(curr_path, '../../../_download')
if not os.path.exists(dirname): if not os.path.exists(dirname):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment