Unverified Commit 1bbc885b authored by Da Zheng's avatar Da Zheng Committed by GitHub
Browse files

move batcher to examples. (#269)

* move pytorch code to examples.

* fix.

* fix tutorial
parent 7aa494b3
......@@ -12,6 +12,15 @@ from dgl.data.tree import SST
from tree_lstm import TreeLSTM
def batcher(dev):
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].to(device),
wordid=batch_trees.ndata['x'].to(device),
label=batch_trees.ndata['y'].to(device))
return batcher_dev
def main(args):
np.random.seed(args.seed)
th.manual_seed(args.seed)
......@@ -28,19 +37,19 @@ def main(args):
trainset = SST()
train_loader = DataLoader(dataset=trainset,
batch_size=args.batch_size,
collate_fn=SST.batcher(device),
collate_fn=batcher(device),
shuffle=True,
num_workers=0)
devset = SST(mode='dev')
dev_loader = DataLoader(dataset=devset,
batch_size=100,
collate_fn=SST.batcher(device),
collate_fn=batcher(device),
shuffle=False,
num_workers=0)
testset = SST(mode='test')
test_loader = DataLoader(dataset=testset,
batch_size=100, collate_fn=SST.batcher(device), shuffle=False, num_workers=0)
batch_size=100, collate_fn=batcher(device), shuffle=False, num_workers=0)
model = TreeLSTM(trainset.num_vocabs,
args.x_size,
......
......@@ -148,13 +148,3 @@ class SST(object):
@property
def num_vocabs(self):
return len(self.vocab)
@staticmethod
def batcher(device):
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].to(device),
wordid=batch_trees.ndata['x'].to(device),
label=batch_trees.ndata['y'].to(device))
return batcher_dev
......@@ -47,6 +47,7 @@ Tree LSTM DGL Tutorial
import dgl
from dgl.data.tree import SST
from dgl.data import SSTBatch
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is a int value stored in the "x" field.
......@@ -335,9 +336,18 @@ optimizer = th.optim.Adagrad(model.parameters(),
lr=lr,
weight_decay=weight_decay)
def batcher(dev):
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].to(device),
wordid=batch_trees.ndata['x'].to(device),
label=batch_trees.ndata['y'].to(device))
return batcher_dev
train_loader = DataLoader(dataset=tiny_sst,
batch_size=5,
collate_fn=SST.batcher(device),
collate_fn=batcher(device),
shuffle=False,
num_workers=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment