Unverified Commit 001d7937 authored by Xiangkun Hu, committed by GitHub

[Dataset] SSTDataset (#1918)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* SSTDataset

* Update tree.py
parent 73b9c6f1
@@ -3,6 +3,7 @@ import time
import warnings
import zipfile
import os
import collections
os.environ['DGLBACKEND'] = 'mxnet'
os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
@@ -16,13 +17,15 @@ import dgl.data as data
from tree_lstm import TreeLSTM
SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
def batcher(ctx):
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return data.SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].as_in_context(ctx),
wordid=batch_trees.ndata['x'].as_in_context(ctx),
label=batch_trees.ndata['y'].as_in_context(ctx))
return SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].as_in_context(ctx),
wordid=batch_trees.ndata['x'].as_in_context(ctx),
label=batch_trees.ndata['y'].as_in_context(ctx))
return batcher_dev
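`SSTBatch` is now defined in the example itself instead of being imported from dgl.data. A minimal sketch of how the collate function above is consumed (`list_of_trees` is hypothetical; the real wiring via gluon's DataLoader follows below):

import mxnet as mx

batch_fn = batcher(mx.cpu())        # batcher_dev bound to a context
# gluon's DataLoader calls batch_fn on a list of DGLGraph trees:
# sst_batch = batch_fn(list_of_trees)
# sst_batch.graph is one batched DGLGraph; mask/wordid/label are per-node
# tensors already placed on the chosen context.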
def prepare_glove():
@@ -52,30 +55,32 @@ def main(args):
else:
print('Requested GPU id {} was not found. Defaulting to CPU implementation'.format(args.gpu))
ctx = mx.cpu()
else:
ctx = mx.cpu()
if args.use_glove:
prepare_glove()
trainset = data.SST()
trainset = data.SSTDataset()
train_loader = gluon.data.DataLoader(dataset=trainset,
batch_size=args.batch_size,
batchify_fn=batcher(ctx),
shuffle=True,
num_workers=0)
devset = data.SST(mode='dev')
devset = data.SSTDataset(mode='dev')
dev_loader = gluon.data.DataLoader(dataset=devset,
batch_size=100,
batchify_fn=batcher(ctx),
shuffle=True,
num_workers=0)
testset = data.SST(mode='test')
testset = data.SSTDataset(mode='test')
test_loader = gluon.data.DataLoader(dataset=testset,
batch_size=100,
batchify_fn=batcher(ctx),
shuffle=False, num_workers=0)
model = TreeLSTM(trainset.num_vocabs,
model = TreeLSTM(trainset.vocab_size,
args.x_size,
args.h_size,
trainset.num_classes,
@@ -85,7 +90,7 @@ def main(args):
ctx=ctx)
print(model)
params_ex_emb =[x for x in model.collect_params().values()
if x.grad_req != 'null' and x.shape[0] != trainset.num_vocabs]
if x.grad_req != 'null' and x.shape[0] != trainset.vocab_size]
params_emb = list(model.embedding.collect_params().values())
for p in params_emb:
p.lr_mult = 0.1
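Setting lr_mult = 0.1 lets a single optimizer fine-tune the embeddings at a tenth of the base learning rate while all other trainable parameters use the full rate. A sketch of the intended wiring (the optimizer choice and hyper-parameters here are hypothetical):

from mxnet import gluon

# With lr_mult = 0.1 on the embedding parameters, one Trainer applies the
# base rate to most weights and 0.1 * lr to the embeddings.
trainer = gluon.Trainer(model.collect_params(), 'adagrad',
                        {'learning_rate': 0.05, 'wd': 1e-4})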
@@ -113,9 +113,6 @@ class TreeLSTM(gluon.nn.Block):
The prediction of each node.
"""
g = batch.graph
g.register_message_func(self.cell.message_func)
g.register_reduce_func(self.cell.reduce_func)
g.register_apply_node_func(self.cell.apply_node_func)
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
wiou = self.cell.W_iou(self.dropout(embeds))
@@ -123,7 +120,10 @@ class TreeLSTM(gluon.nn.Block):
g.ndata['h'] = h
g.ndata['c'] = c
# propagate
dgl.prop_nodes_topo(g)
dgl.prop_nodes_topo(g,
message_func=self.cell.message_func,
reduce_func=self.cell.reduce_func,
apply_node_func=self.cell.apply_node_func)
# compute logits
h = self.dropout(g.ndata.pop('h'))
logits = self.linear(h)
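With the register_* calls removed, the user-defined functions now travel with each dgl.prop_nodes_topo call instead of being attached to the graph. A self-contained sketch of the new call pattern on a toy DAG (not the Tree-LSTM cell itself; shown with the PyTorch backend for brevity even though this hunk is from the MXNet example):

import dgl
import torch

# Two leaves feeding one root, mirroring a tree's child -> parent edges.
g = dgl.graph(([0, 1], [2, 2]))
g.ndata['h'] = torch.tensor([[1.0], [2.0], [0.0]])

def message_func(edges):
    return {'m': edges.src['h']}

def reduce_func(nodes):
    return {'h': nodes.mailbox['m'].sum(1)}

dgl.prop_nodes_topo(g, message_func=message_func, reduce_func=reduce_func)
print(g.ndata['h'][2])   # tensor([3.]) -- the children were summed into the root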
@@ -9,11 +9,12 @@ import torch.optim as optim
from torch.utils.data import DataLoader
import dgl
from dgl.data.tree import SST, SSTBatch
from dgl.data.tree import SSTDataset
from tree_lstm import TreeLSTM
SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
def batcher(device):
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
@@ -36,24 +37,24 @@ def main(args):
if cuda:
th.cuda.set_device(args.gpu)
trainset = SST()
trainset = SSTDataset()
train_loader = DataLoader(dataset=trainset,
batch_size=args.batch_size,
collate_fn=batcher(device),
shuffle=True,
num_workers=0)
devset = SST(mode='dev')
devset = SSTDataset(mode='dev')
dev_loader = DataLoader(dataset=devset,
batch_size=100,
collate_fn=batcher(device),
shuffle=False,
num_workers=0)
testset = SST(mode='test')
testset = SSTDataset(mode='test')
test_loader = DataLoader(dataset=testset,
batch_size=100, collate_fn=batcher(device), shuffle=False, num_workers=0)
model = TreeLSTM(trainset.num_vocabs,
model = TreeLSTM(trainset.vocab_size,
args.x_size,
args.h_size,
trainset.num_classes,
@@ -61,7 +62,7 @@ def main(args):
cell_type='childsum' if args.child_sum else 'nary',
pretrained_emb = trainset.pretrained_emb).to(device)
print(model)
params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.num_vocabs]
params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.vocab_size]
params_emb = list(model.embedding.parameters())
for p in params_ex_emb:
@@ -4,7 +4,7 @@ from __future__ import absolute_import
from . import citation_graph as citegrh
from .citation_graph import CoraBinary, CitationGraphDataset
from .minigc import *
from .tree import *
from .tree import SST, SSTDataset
from .utils import *
from .sbm import SBMMixture
from .reddit import RedditDataset
"""Tree-structured data.
Including:
- Stanford Sentiment Treebank
"""
from __future__ import absolute_import
from collections import namedtuple, OrderedDict
from collections import OrderedDict
import networkx as nx
import numpy as np
import os
from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
load_info, deprecate_property
from ..convert import from_networkx
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
__all__ = ['SSTBatch', 'SST']
__all__ = ['SST', 'SSTDataset']
_urls = {
'sst' : 'dataset/sst.zip',
}
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
class SSTDataset(DGLBuiltinDataset):
r"""Stanford Sentiment Treebank dataset.
class SST(object):
"""Stanford Sentiment Treebank dataset.
.. deprecated:: 0.5.0
`trees` is deprecated; it is replaced by:
>>> dataset = SSTDataset()
>>> for tree in dataset:
...     # your code here
...
>>>
`num_vocabs` is deprecated; it is replaced by `vocab_size`
Each sample is the constituency tree of a sentence. The leaf nodes
represent words. The word is an int value stored in the ``x`` feature field.
@@ -33,80 +37,145 @@ class SST(object):
Each node also has a sentiment annotation: 5 classes (very negative,
negative, neutral, positive and very positive). The sentiment label is an
int value stored in the ``y`` feature field.
Official site: http://nlp.stanford.edu/sentiment/index.html
.. note::
This dataset class is compatible with pytorch's :class:`Dataset` class.
.. note::
All the samples will be loaded and preprocessed in memory first.
Statistics
----------
Train examples: 8,544
Dev examples: 1,101
Test examples: 2,210
Number of classes for each node: 5
Parameters
----------
mode : str, optional
Can be ``'train'``, ``'val'``, ``'test'`` and specifies which data file to use.
Should be one of ['train', 'dev', 'test', 'tiny']
Default: train
glove_embed_file : str, optional
The path to pretrained glove embedding file.
Default: None
vocab_file : str, optional
Optional vocabulary file.
Optional vocabulary file. If not given, the default vocabulary file is used.
Default: None
raw_dir : str
Directory that stores the downloaded raw data, or to which it will be downloaded.
Default: ~/.dgl/
force_reload : bool
Whether to reload the dataset. Default: False
verbose : bool
Whether to print out progress information. Default: True.
Attributes
----------
vocab : OrderedDict
Vocabulary of the dataset
trees : list
A list of DGLGraph objects
num_classes : int
Number of classes for each node
pretrained_emb : Tensor
Pretrained glove embedding with respect to the vocabulary.
vocab_size : int
The size of the vocabulary
num_vocabs : int
The size of the vocabulary (deprecated; use ``vocab_size`` instead)
Notes
-----
All the samples will be loaded and preprocessed in memory first.
Examples
--------
>>> # get dataset
>>> train_data = SSTDataset()
>>> dev_data = SSTDataset(mode='dev')
>>> test_data = SSTDataset(mode='test')
>>> tiny_data = SSTDataset(mode='tiny')
>>>
>>> len(train_data)
8544
>>> train_data.num_classes
5
>>> glove_embed = train_data.pretrained_emb
>>> train_data.vocab_size
19536
>>> train_data[0]
DGLGraph(num_nodes=71, num_edges=70,
ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(),
dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={})
>>> for tree in train_data:
... input_ids = tree.ndata['x']
... labels = tree.ndata['y']
... mask = tree.ndata['mask']
... # your code here
>>>
"""
PAD_WORD=-1 # special pad word id
UNK_WORD=-1 # out-of-vocabulary word id
def __init__(self, mode='train', vocab_file=None):
PAD_WORD = -1 # special pad word id
UNK_WORD = -1 # out-of-vocabulary word id
def __init__(self,
mode='train',
glove_embed_file=None,
vocab_file=None,
raw_dir=None,
force_reload=False,
verbose=False):
assert mode in ['train', 'dev', 'test', 'tiny']
_url = _get_dgl_url('dataset/sst.zip')
self._glove_embed_file = glove_embed_file if mode == 'train' else None
self.mode = mode
self.dir = get_download_dir()
self.zip_file_path='{}/sst.zip'.format(self.dir)
self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
self.pretrained_emb = None
self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
self.trees = []
self.num_classes = 5
print('Preprocessing...')
self._load()
print('Dataset creation finished. #Trees:', len(self.trees))
def _download(self):
download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
@retry_method_with_fix(_download)
def _load(self):
self._vocab_file = vocab_file
super(SSTDataset, self).__init__(name='sst',
url=_url,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
def process(self):
from nltk.corpus.reader import BracketParseCorpusReader
# load vocab file
self.vocab = OrderedDict()
with open(self.vocab_file, encoding='utf-8') as vf:
self._vocab = OrderedDict()
vocab_file = self._vocab_file if self._vocab_file is not None else os.path.join(self.raw_path, 'vocab.txt')
with open(vocab_file, encoding='utf-8') as vf:
for line in vf.readlines():
line = line.strip()
self.vocab[line] = len(self.vocab)
self._vocab[line] = len(self._vocab)
# filter glove
if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
glove_emb = {}
with open(self.pretrained_file, 'r', encoding='utf-8') as pf:
with open(self._glove_embed_file, 'r', encoding='utf-8') as pf:
for line in pf.readlines():
sp = line.split(' ')
if sp[0].lower() in self.vocab:
if sp[0].lower() in self._vocab:
glove_emb[sp[0].lower()] = np.asarray([float(x) for x in sp[1:]])
files = ['{}.txt'.format(self.mode)]
corpus = BracketParseCorpusReader('{}/sst'.format(self.dir), files)
corpus = BracketParseCorpusReader(self.raw_path, files)
sents = corpus.parsed_sents(files[0])
#initialize with glove
# initialize with glove
pretrained_emb = []
fail_cnt = 0
for line in self.vocab.keys():
if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
for line in self._vocab.keys():
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
if not line.lower() in glove_emb:
fail_cnt += 1
pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300)))
if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
self.pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
print('Miss word in GloVe {0:.4f}'.format(1.0*fail_cnt/len(self.pretrained_emb)))
self._pretrained_emb = None
if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
print('Miss word in GloVe {0:.4f}'.format(1.0 * fail_cnt / len(self._pretrained_emb)))
# build trees
self._trees = []
for sent in sents:
self.trees.append(self._build_tree(sent))
self._trees.append(self._build_tree(sent))
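A sketch of what process() produces when a GloVe file is supplied (the path is an example; any 300-dimensional word2vec-style text file matching the parsing above would do). Rows of pretrained_emb align with the vocabulary order, and words missing from GloVe fall back to U(-0.05, 0.05) random vectors:

from dgl.data import SSTDataset

# hypothetical local path to the embedding text file
train = SSTDataset(mode='train', glove_embed_file='glove.840B.300d.txt')
emb = train.pretrained_emb    # backend tensor of shape (vocab_size, 300)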
def _build_tree(self, root):
g = nx.DiGraph()
def _rec_build(nid, node):
for child in node:
cid = g.number_of_nodes()
@@ -115,40 +184,95 @@ class SST(object):
word = self.vocab.get(child[0].lower(), self.UNK_WORD)
g.add_node(cid, x=word, y=int(child.label()), mask=1)
else:
g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()), mask=0)
g.add_node(cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0)
_rec_build(cid, child)
g.add_edge(cid, nid)
# add root
g.add_node(0, x=SST.PAD_WORD, y=int(root.label()), mask=0)
g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
_rec_build(0, root)
ret = from_networkx(g, node_attrs=['x', 'y', 'mask'])
return ret
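Edges are added child-to-parent (g.add_edge(cid, nid)), so topological propagation runs from the leaves up to the root. A quick sketch of inspecting a built tree:

from dgl.data import SSTDataset

tiny = SSTDataset(mode='tiny')        # the five-tree subset used in the tutorial
g = tiny[0]
print(g.num_nodes(), g.num_edges())   # a tree: N nodes and N - 1 edges
print(int(g.ndata['mask'].sum()))     # number of leaf (word) nodes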
def has_cache(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
ret = os.path.exists(graph_path)
if self.mode == 'train':
info_path = os.path.join(self.save_path, 'info.pkl')  # must match save()/load()
ret = ret and os.path.exists(info_path)
return ret
def save(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
save_graphs(graph_path, self._trees)
if self.mode == 'train':
info_path = os.path.join(self.save_path, 'info.pkl')
save_info(info_path, {'vocab': self.vocab, 'embed': self.pretrained_emb})
def load(self):
graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
self._trees = load_graphs(graph_path)[0]
info_path = os.path.join(self.save_path, 'info.pkl')
if os.path.exists(info_path):
info = load_info(info_path)
self._vocab = info['vocab']
self._pretrained_emb = info['embed']
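Together, has_cache, save, and load plug into the standard DGLDataset lifecycle: construction loads the cached binaries when they exist and reprocesses otherwise. A sketch:

from dgl.data import SSTDataset

train = SSTDataset(mode='train')                      # process() then save()
train_again = SSTDataset(mode='train')                # has_cache() -> load()
fresh = SSTDataset(mode='train', force_reload=True)   # always reprocess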
@property
def trees(self):
deprecate_property('dataset.trees', '[dataset[i] for i in range(len(dataset))]')
return self._trees
@property
def vocab(self):
r""" Vocabulary
Returns
-------
OrderedDict
"""
return self._vocab
@property
def pretrained_emb(self):
r"""Pre-trained word embedding, if given."""
return self._pretrained_emb
def __getitem__(self, idx):
"""Get the tree with index idx.
r""" Get graph by index
Parameters
----------
idx : int
Tree index.
Returns
-------
dgl.DGLGraph
Tree.
graph structure, word id for each node, node labels and masks
- ndata['x']: word id of the node
- ndata['y']: label of the node
- ndata['mask']: 1 if the node is a leaf, otherwise 0
"""
return self.trees[idx]
return self._trees[idx]
def __len__(self):
"""Get the number of trees in the dataset.
Returns
-------
int
Number of trees.
"""
return len(self.trees)
r"""Number of graphs in the dataset."""
return len(self._trees)
@property
def num_vocabs(self):
return len(self.vocab)
deprecate_property('dataset.num_vocabs', 'dataset.vocab_size')
return self.vocab_size
@property
def vocab_size(self):
r"""Vocabulary size."""
return len(self._vocab)
@property
def num_classes(self):
r"""Number of classes for each node."""
return 5
SST = SSTDataset
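The module-level alias keeps old imports working; both names resolve to the same class:

from dgl.data.tree import SST, SSTDataset

assert SST is SSTDataset       # an alias, not a copy
ds = SST(mode='tiny')          # legacy spelling still builds an SSTDataset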
@@ -48,15 +48,19 @@ Tutorial: Tree-LSTM in DGL
# at the first one.
#
from collections import namedtuple
import dgl
from dgl.data.tree import SST
from dgl.data import SSTBatch
from dgl.data.tree import SSTDataset
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SST(mode='tiny') # the "tiny" set has only five trees
trainset = SSTDataset(mode='tiny') # the "tiny" set has only five trees
tiny_sst = trainset.trees
num_vocabs = trainset.num_vocabs
num_classes = trainset.num_classes
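The tutorial still reads the deprecated accessors; per the deprecation messages above, the equivalents would be:

tiny_sst = [trainset[i] for i in range(len(trainset))]   # replaces trainset.trees
num_vocabs = trainset.vocab_size                         # replaces trainset.num_vocabs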