"...text-generation-inference.git" did not exist on "211b54ac41cae9a369f3d74bd6cc666ff4a0c526"
Unverified commit 001d7937, authored by Xiangkun Hu, committed by GitHub

[Dataset] SSTDataset (#1918)

* PPIDataset

* Revert "PPIDataset"

This reverts commit 264bd0c960cfa698a7bb946dad132bf52c2d0c8a.

* SSTDataset

* Update tree.py
parent 73b9c6f1
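
For reviewers, a minimal sketch of the renamed API this commit introduces (names are taken from the diffs below; the snippet assumes a DGL build with the PyTorch backend):

```python
# Hedged sketch: exercises the renamed dataset API from this commit.
from dgl.data.tree import SSTDataset

trainset = SSTDataset()      # replaces the old SST() / data.SST()
print(len(trainset))         # number of constituency trees (8544 for 'train')
print(trainset.num_classes)  # 5 sentiment classes
print(trainset.vocab_size)   # replaces the deprecated num_vocabs
tree = trainset[0]           # a DGLGraph with 'x', 'y' and 'mask' node data
```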
@@ -3,6 +3,7 @@ import time
 import warnings
 import zipfile
 import os
+import collections
 os.environ['DGLBACKEND'] = 'mxnet'
 os.environ['MXNET_GPU_MEM_POOL_TYPE'] = 'Round'
@@ -16,10 +17,12 @@ import dgl.data as data
 from tree_lstm import TreeLSTM

+SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
+
 def batcher(ctx):
     def batcher_dev(batch):
         batch_trees = dgl.batch(batch)
-        return data.SSTBatch(graph=batch_trees,
+        return SSTBatch(graph=batch_trees,
                         mask=batch_trees.ndata['mask'].as_in_context(ctx),
                         wordid=batch_trees.ndata['x'].as_in_context(ctx),
                         label=batch_trees.ndata['y'].as_in_context(ctx))
@@ -52,30 +55,32 @@ def main(args):
         else:
             print('Requested GPU id {} was not found. Defaulting to CPU implementation'.format(args.gpu))
             ctx = mx.cpu()
+    else:
+        ctx = mx.cpu()
     if args.use_glove:
         prepare_glove()
-    trainset = data.SST()
+    trainset = data.SSTDataset()
     train_loader = gluon.data.DataLoader(dataset=trainset,
                                          batch_size=args.batch_size,
                                          batchify_fn=batcher(ctx),
                                          shuffle=True,
                                          num_workers=0)
-    devset = data.SST(mode='dev')
+    devset = data.SSTDataset(mode='dev')
     dev_loader = gluon.data.DataLoader(dataset=devset,
                                        batch_size=100,
                                        batchify_fn=batcher(ctx),
                                        shuffle=True,
                                        num_workers=0)
-    testset = data.SST(mode='test')
+    testset = data.SSTDataset(mode='test')
     test_loader = gluon.data.DataLoader(dataset=testset,
                                         batch_size=100,
                                         batchify_fn=batcher(ctx),
                                         shuffle=False, num_workers=0)
-    model = TreeLSTM(trainset.num_vocabs,
+    model = TreeLSTM(trainset.vocab_size,
                      args.x_size,
                      args.h_size,
                      trainset.num_classes,
@@ -85,7 +90,7 @@ def main(args):
                      ctx=ctx)
     print(model)
     params_ex_emb =[x for x in model.collect_params().values()
-                    if x.grad_req != 'null' and x.shape[0] != trainset.num_vocabs]
+                    if x.grad_req != 'null' and x.shape[0] != trainset.vocab_size]
     params_emb = list(model.embedding.collect_params().values())
     for p in params_emb:
         p.lr_mult = 0.1
...
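Since `SSTBatch` no longer ships with `dgl.data`, the MXNet example now defines its own batch container. A condensed, hedged sketch of the resulting input pipeline (the batch size and CPU context are illustrative):

```python
import os
os.environ['DGLBACKEND'] = 'mxnet'  # must be set before importing dgl, as in the script above

import collections
import dgl
import dgl.data as data
import mxnet as mx
from mxnet import gluon

# The batch container is now local to the example, not part of dgl.data.
SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

def batcher(ctx):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)  # merge a list of trees into one graph
        return SSTBatch(graph=batch_trees,
                        mask=batch_trees.ndata['mask'].as_in_context(ctx),
                        wordid=batch_trees.ndata['x'].as_in_context(ctx),
                        label=batch_trees.ndata['y'].as_in_context(ctx))
    return batcher_dev

ctx = mx.cpu()
trainset = data.SSTDataset()
train_loader = gluon.data.DataLoader(dataset=trainset, batch_size=32,
                                     batchify_fn=batcher(ctx), shuffle=True)
for batch in train_loader:
    pass  # batch.graph, batch.wordid, batch.mask, batch.label feed the model
```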
@@ -113,9 +113,6 @@ class TreeLSTM(gluon.nn.Block):
             The prediction of each node.
         """
         g = batch.graph
-        g.register_message_func(self.cell.message_func)
-        g.register_reduce_func(self.cell.reduce_func)
-        g.register_apply_node_func(self.cell.apply_node_func)
         # feed embedding
         embeds = self.embedding(batch.wordid * batch.mask)
         wiou = self.cell.W_iou(self.dropout(embeds))
@@ -123,7 +120,10 @@ class TreeLSTM(gluon.nn.Block):
         g.ndata['h'] = h
         g.ndata['c'] = c
         # propagate
-        dgl.prop_nodes_topo(g)
+        dgl.prop_nodes_topo(g,
+                            message_func=self.cell.message_func,
+                            reduce_func=self.cell.reduce_func,
+                            apply_node_func=self.cell.apply_node_func)
         # compute logits
         h = self.dropout(g.ndata.pop('h'))
         logits = self.linear(h)
...
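The hunks above reflect a DGL API change: the per-graph `register_*` calls are gone, and the message, reduce, and apply functions are passed straight to `dgl.prop_nodes_topo`. A toy sketch of the new calling convention (assuming DGL >= 0.5; the builtin sum propagation stands in for the TreeLSTM cell's functions):

```python
import dgl
import dgl.function as fn
import torch as th

# A 3-node tree with edges pointing child -> parent (nodes 0 and 1 feed node 2).
g = dgl.graph(([0, 1], [2, 2]))
g.ndata['h'] = th.tensor([[1.], [2.], [0.]])

# New style: the UDFs are arguments of the traversal, not graph state.
dgl.prop_nodes_topo(g,
                    message_func=fn.copy_u('h', 'm'),
                    reduce_func=fn.sum('m', 'h'))
print(g.ndata['h'])  # node 2 now holds the sum of its children: [[1.], [2.], [3.]]
```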
@@ -9,11 +9,12 @@ import torch.optim as optim
 from torch.utils.data import DataLoader

 import dgl
-from dgl.data.tree import SST, SSTBatch
+from dgl.data.tree import SSTDataset
 from tree_lstm import TreeLSTM

 SSTBatch = collections.namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

 def batcher(device):
     def batcher_dev(batch):
         batch_trees = dgl.batch(batch)
@@ -36,24 +37,24 @@ def main(args):
     if cuda:
         th.cuda.set_device(args.gpu)
-    trainset = SST()
+    trainset = SSTDataset()
     train_loader = DataLoader(dataset=trainset,
                               batch_size=args.batch_size,
                               collate_fn=batcher(device),
                               shuffle=True,
                               num_workers=0)
-    devset = SST(mode='dev')
+    devset = SSTDataset(mode='dev')
     dev_loader = DataLoader(dataset=devset,
                             batch_size=100,
                             collate_fn=batcher(device),
                             shuffle=False,
                             num_workers=0)
-    testset = SST(mode='test')
+    testset = SSTDataset(mode='test')
     test_loader = DataLoader(dataset=testset,
                              batch_size=100, collate_fn=batcher(device), shuffle=False, num_workers=0)
-    model = TreeLSTM(trainset.num_vocabs,
+    model = TreeLSTM(trainset.vocab_size,
                      args.x_size,
                      args.h_size,
                      trainset.num_classes,
@@ -61,7 +62,7 @@ def main(args):
                      cell_type='childsum' if args.child_sum else 'nary',
                      pretrained_emb = trainset.pretrained_emb).to(device)
     print(model)
-    params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.num_vocabs]
+    params_ex_emb =[x for x in list(model.parameters()) if x.requires_grad and x.size(0)!=trainset.vocab_size]
     params_emb = list(model.embedding.parameters())
     for p in params_ex_emb:
...
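The embedding-versus-rest parameter split is unchanged apart from the `vocab_size` rename. For reference, a hedged sketch of how such a split typically feeds an optimizer with per-group settings (the concrete values are illustrative, not taken from this diff; `model` and `trainset` are the objects from the script above):

```python
import torch.optim as optim

# Embedding rows are identified by a first dimension equal to the vocabulary size.
params_ex_emb = [p for p in model.parameters()
                 if p.requires_grad and p.size(0) != trainset.vocab_size]
params_emb = list(model.embedding.parameters())

optimizer = optim.Adagrad([
    {'params': params_ex_emb, 'lr': 0.05, 'weight_decay': 1e-4},
    {'params': params_emb, 'lr': 0.05 * 0.1},  # embeddings learn 10x slower
])
```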
@@ -4,7 +4,7 @@ from __future__ import absolute_import
 from . import citation_graph as citegrh
 from .citation_graph import CoraBinary, CitationGraphDataset
 from .minigc import *
-from .tree import *
+from .tree import SST, SSTDataset
 from .utils import *
 from .sbm import SBMMixture
 from .reddit import RedditDataset
...
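Narrowing the wildcard to an explicit import list keeps both the new class and its backward-compatible alias (defined at the bottom of `tree.py`, below) reachable from `dgl.data`:

```python
# After this commit, both spellings resolve to the same class:
from dgl.data import SST, SSTDataset

assert SST is SSTDataset  # SST is kept only as a compatibility alias
```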
"""Tree-structured data. """Tree-structured data.
Including: Including:
- Stanford Sentiment Treebank - Stanford Sentiment Treebank
""" """
from __future__ import absolute_import from __future__ import absolute_import
from collections import namedtuple, OrderedDict from collections import OrderedDict
import networkx as nx import networkx as nx
import numpy as np import numpy as np
import os import os
from .dgl_dataset import DGLBuiltinDataset
from .. import backend as F from .. import backend as F
from .utils import _get_dgl_url, save_graphs, save_info, load_graphs, \
load_info, deprecate_property
from ..convert import from_networkx from ..convert import from_networkx
from .utils import download, extract_archive, get_download_dir, _get_dgl_url
from ..utils import retry_method_with_fix
__all__ = ['SSTBatch', 'SST'] __all__ = ['SST', 'SSTDataset']
_urls = {
'sst' : 'dataset/sst.zip',
}
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label']) class SSTDataset(DGLBuiltinDataset):
r"""Stanford Sentiment Treebank dataset.
class SST(object): .. deprecated:: 0.5.0
"""Stanford Sentiment Treebank dataset. `trees` is deprecated, it is replaced by:
>>> dataset = SSTDataset()
>>> for tree in dataset:
.... # your code here
....
>>>
`num_vocabs` is deprecated, it is replaced by `vocab_size`
Each sample is the constituency tree of a sentence. The leaf nodes Each sample is the constituency tree of a sentence. The leaf nodes
represent words. The word is a int value stored in the ``x`` feature field. represent words. The word is a int value stored in the ``x`` feature field.
@@ -33,80 +37,145 @@ class SST(object):
     Each node also has a sentiment annotation: 5 classes (very negative,
     negative, neutral, positive and very positive). The sentiment label is an
     int value stored in the ``y`` feature field.

+    Official site: http://nlp.stanford.edu/sentiment/index.html
+
-    .. note::
-        This dataset class is compatible with pytorch's :class:`Dataset` class.
-
-    .. note::
-        All the samples will be loaded and preprocessed in the memory first.
+    Statistics
+    ----------
+    Train examples: 8,544
+    Dev examples: 1,101
+    Test examples: 2,210
+    Number of classes for each node: 5

     Parameters
     ----------
     mode : str, optional
-        Can be ``'train'``, ``'val'``, ``'test'`` and specifies which data file to use.
+        Should be one of ['train', 'dev', 'test', 'tiny']
+        Default: train
+    glove_embed_file : str, optional
+        The path to pretrained glove embedding file.
+        Default: None
     vocab_file : str, optional
-        Optional vocabulary file.
+        Optional vocabulary file. If not given, the default vocabulary file is used.
+        Default: None
+    raw_dir : str
+        Raw file directory to download/contains the input data directory.
+        Default: ~/.dgl/
+    force_reload : bool
+        Whether to reload the dataset. Default: False
+    verbose : bool
+        Whether to print out progress information. Default: True.
+
+    Attributes
+    ----------
+    vocab : OrderedDict
+        Vocabulary of the dataset
+    trees : list
+        A list of DGLGraph objects
+    num_classes : int
+        Number of classes for each node
+    pretrained_emb : Tensor
+        Pretrained glove embedding with respect to the vocabulary.
+    vocab_size : int
+        The size of the vocabulary
+    num_vocabs : int
+        The size of the vocabulary
+
+    Notes
+    -----
+    All the samples will be loaded and preprocessed in the memory first.
+
+    Examples
+    --------
+    >>> # get dataset
+    >>> train_data = SSTDataset()
+    >>> dev_data = SSTDataset(mode='dev')
+    >>> test_data = SSTDataset(mode='test')
+    >>> tiny_data = SSTDataset(mode='tiny')
+    >>>
+    >>> len(train_data)
+    8544
+    >>> train_data.num_classes
+    5
+    >>> glove_embed = train_data.pretrained_emb
+    >>> train_data.vocab_size
+    19536
+    >>> train_data[0]
+    DGLGraph(num_nodes=71, num_edges=70,
+             ndata_schemes={'x': Scheme(shape=(), dtype=torch.int64), 'y': Scheme(shape=(),
+             dtype=torch.int64), 'mask': Scheme(shape=(), dtype=torch.int64)}
+             edata_schemes={})
+    >>> for tree in train_data:
+    ...     input_ids = tree.ndata['x']
+    ...     labels = tree.ndata['y']
+    ...     mask = tree.ndata['mask']
+    ...     # your code here
     """
-    PAD_WORD=-1 # special pad word id
-    UNK_WORD=-1 # out-of-vocabulary word id
-    def __init__(self, mode='train', vocab_file=None):
+    PAD_WORD = -1  # special pad word id
+    UNK_WORD = -1  # out-of-vocabulary word id
+
+    def __init__(self,
+                 mode='train',
+                 glove_embed_file=None,
+                 vocab_file=None,
+                 raw_dir=None,
+                 force_reload=False,
+                 verbose=False):
+        assert mode in ['train', 'dev', 'test', 'tiny']
+        _url = _get_dgl_url('dataset/sst.zip')
+        self._glove_embed_file = glove_embed_file if mode == 'train' else None
         self.mode = mode
-        self.dir = get_download_dir()
-        self.zip_file_path='{}/sst.zip'.format(self.dir)
-        self.pretrained_file = 'glove.840B.300d.txt' if mode == 'train' else ''
-        self.pretrained_emb = None
-        self.vocab_file = '{}/sst/vocab.txt'.format(self.dir) if vocab_file is None else vocab_file
-        self.trees = []
-        self.num_classes = 5
-        print('Preprocessing...')
-        self._load()
-        print('Dataset creation finished. #Trees:', len(self.trees))
-
-    def _download(self):
-        download(_get_dgl_url(_urls['sst']), path=self.zip_file_path)
-        extract_archive(self.zip_file_path, '{}/sst'.format(self.dir))
-
-    @retry_method_with_fix(_download)
-    def _load(self):
+        self._vocab_file = vocab_file
+        super(SSTDataset, self).__init__(name='sst',
+                                         url=_url,
+                                         raw_dir=raw_dir,
+                                         force_reload=force_reload,
+                                         verbose=verbose)
+
+    def process(self):
         from nltk.corpus.reader import BracketParseCorpusReader
         # load vocab file
-        self.vocab = OrderedDict()
-        with open(self.vocab_file, encoding='utf-8') as vf:
+        self._vocab = OrderedDict()
+        vocab_file = self._vocab_file if self._vocab_file is not None else os.path.join(self.raw_path, 'vocab.txt')
+        with open(vocab_file, encoding='utf-8') as vf:
             for line in vf.readlines():
                 line = line.strip()
-                self.vocab[line] = len(self.vocab)
+                self._vocab[line] = len(self._vocab)
         # filter glove
-        if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
+        if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
             glove_emb = {}
-            with open(self.pretrained_file, 'r', encoding='utf-8') as pf:
+            with open(self._glove_embed_file, 'r', encoding='utf-8') as pf:
                 for line in pf.readlines():
                     sp = line.split(' ')
-                    if sp[0].lower() in self.vocab:
+                    if sp[0].lower() in self._vocab:
                         glove_emb[sp[0].lower()] = np.asarray([float(x) for x in sp[1:]])
         files = ['{}.txt'.format(self.mode)]
-        corpus = BracketParseCorpusReader('{}/sst'.format(self.dir), files)
+        corpus = BracketParseCorpusReader(self.raw_path, files)
         sents = corpus.parsed_sents(files[0])
-        #initialize with glove
+        # initialize with glove
         pretrained_emb = []
         fail_cnt = 0
-        for line in self.vocab.keys():
-            if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
+        for line in self._vocab.keys():
+            if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
                 if not line.lower() in glove_emb:
                     fail_cnt += 1
                 pretrained_emb.append(glove_emb.get(line.lower(), np.random.uniform(-0.05, 0.05, 300)))
-        if self.pretrained_file != '' and os.path.exists(self.pretrained_file):
-            self.pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
-            print('Miss word in GloVe {0:.4f}'.format(1.0*fail_cnt/len(self.pretrained_emb)))
+        self._pretrained_emb = None
+        if self._glove_embed_file is not None and os.path.exists(self._glove_embed_file):
+            self._pretrained_emb = F.tensor(np.stack(pretrained_emb, 0))
+            print('Miss word in GloVe {0:.4f}'.format(1.0 * fail_cnt / len(self._pretrained_emb)))
         # build trees
+        self._trees = []
         for sent in sents:
-            self.trees.append(self._build_tree(sent))
+            self._trees.append(self._build_tree(sent))

     def _build_tree(self, root):
         g = nx.DiGraph()
         def _rec_build(nid, node):
             for child in node:
                 cid = g.number_of_nodes()
@@ -115,40 +184,95 @@ class SST(object):
                     word = self.vocab.get(child[0].lower(), self.UNK_WORD)
                     g.add_node(cid, x=word, y=int(child.label()), mask=1)
                 else:
-                    g.add_node(cid, x=SST.PAD_WORD, y=int(child.label()), mask=0)
+                    g.add_node(cid, x=SSTDataset.PAD_WORD, y=int(child.label()), mask=0)
                     _rec_build(cid, child)
                 g.add_edge(cid, nid)
         # add root
-        g.add_node(0, x=SST.PAD_WORD, y=int(root.label()), mask=0)
+        g.add_node(0, x=SSTDataset.PAD_WORD, y=int(root.label()), mask=0)
         _rec_build(0, root)
         ret = from_networkx(g, node_attrs=['x', 'y', 'mask'])
         return ret

+    def has_cache(self):
+        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
+        ret = os.path.exists(graph_path)
+        if self.mode == 'train':
+            info_path = os.path.join(self.save_path, 'graph_info.pkl')
+            ret = ret and os.path.exists(info_path)
+        return ret
+
+    def save(self):
+        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
+        save_graphs(graph_path, self._trees)
+        if self.mode == 'train':
+            info_path = os.path.join(self.save_path, 'info.pkl')
+            save_info(info_path, {'vocab': self.vocab, 'embed': self.pretrained_emb})
+
+    def load(self):
+        graph_path = os.path.join(self.save_path, self.mode + '_dgl_graph.bin')
+        self._trees = load_graphs(graph_path)[0]
+        info_path = os.path.join(self.save_path, 'info.pkl')
+        if os.path.exists(info_path):
+            info = load_info(info_path)
+            self._vocab = info['vocab']
+            self._pretrained_emb = info['embed']
+
+    @property
+    def trees(self):
+        deprecate_property('dataset.trees', '[dataset[i] for i in range(len(dataset))]')
+        return self._trees
+
+    @property
+    def vocab(self):
+        r"""Vocabulary
+
+        Returns
+        -------
+        OrderedDict
+        """
+        return self._vocab
+
+    @property
+    def pretrained_emb(self):
+        r"""Pre-trained word embedding, if given."""
+        return self._pretrained_emb
+
     def __getitem__(self, idx):
-        """Get the tree with index idx.
+        r"""Get graph by index
+
         Parameters
         ----------
         idx : int
-            Tree index.
+
         Returns
         -------
         dgl.DGLGraph
-            Tree.
+            graph structure, word id for each node, node labels and masks
+
+            - ndata['x']: word id of the node
+            - ndata['y']: label of the node
+            - ndata['mask']: 1 if the node is a leaf, otherwise 0
         """
-        return self.trees[idx]
+        return self._trees[idx]

     def __len__(self):
-        """Get the number of trees in the dataset.
-
-        Returns
-        -------
-        int
-            Number of trees.
-        """
-        return len(self.trees)
+        r"""Number of graphs in the dataset."""
+        return len(self._trees)

     @property
     def num_vocabs(self):
-        return len(self.vocab)
+        deprecate_property('dataset.num_vocabs', 'dataset.vocab_size')
+        return self.vocab_size
+
+    @property
+    def vocab_size(self):
+        r"""Vocabulary size."""
+        return len(self._vocab)
+
+    @property
+    def num_classes(self):
+        r"""Number of classes for each node."""
+        return 5
+
+
+SST = SSTDataset
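
The new class plugs into the `DGLBuiltinDataset` lifecycle: `process()` builds the trees once, `save()`/`load()` round-trip them with `save_graphs`/`load_graphs`, and `has_cache()` decides whether `process()` can be skipped. A minimal, hedged sketch of that contract for a user-defined dataset (built on the public `DGLDataset` base; `build_graphs_somehow` is a hypothetical placeholder):

```python
import os
from dgl.data import DGLDataset
from dgl.data.utils import save_graphs, load_graphs

class MyTreeDataset(DGLDataset):
    """Illustrative dataset following the same cache contract as SSTDataset."""

    def __init__(self, raw_dir=None, force_reload=False, verbose=False):
        super(MyTreeDataset, self).__init__(name='my_tree',
                                            raw_dir=raw_dir,
                                            force_reload=force_reload,
                                            verbose=verbose)

    def process(self):
        # Expensive one-time construction (parsing, graph building, ...).
        self._graphs = build_graphs_somehow()  # hypothetical helper

    def save(self):
        save_graphs(os.path.join(self.save_path, 'trees.bin'), self._graphs)

    def load(self):
        self._graphs = load_graphs(os.path.join(self.save_path, 'trees.bin'))[0]

    def has_cache(self):
        # True -> load() is used and process() is skipped on the next run.
        return os.path.exists(os.path.join(self.save_path, 'trees.bin'))

    def __getitem__(self, idx):
        return self._graphs[idx]

    def __len__(self):
        return len(self._graphs)
```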
@@ -48,15 +48,19 @@ Tutorial: Tree-LSTM in DGL
 # at the first one.
 #

+from collections import namedtuple
+
 import dgl
-from dgl.data.tree import SST
-from dgl.data import SSTBatch
+from dgl.data.tree import SSTDataset
+
+SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])

 # Each sample in the dataset is a constituency tree. The leaf nodes
 # represent words. The word is an int value stored in the "x" field.
 # The non-leaf nodes have a special word PAD_WORD. The sentiment
 # label is stored in the "y" feature field.

-trainset = SST(mode='tiny')  # the "tiny" set has only five trees
+trainset = SSTDataset(mode='tiny')  # the "tiny" set has only five trees
 tiny_sst = trainset.trees
 num_vocabs = trainset.num_vocabs
 num_classes = trainset.num_classes
...
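The tutorial still reads the deprecated `trees` and `num_vocabs` attributes; per the deprecation notices added in `tree.py` above, the forward-compatible spellings would be:

```python
# Forward-compatible equivalents of the deprecated attributes used above:
tiny_sst = [trainset[i] for i in range(len(trainset))]  # instead of trainset.trees
num_vocabs = trainset.vocab_size                        # instead of trainset.num_vocabs
```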