"references/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "80cfc2c8c09314696e72b897928473e05a325f1f"
Unverified commit 21255b65, authored by Minjie Wang, committed by GitHub

[Bugfix] tolist and dependencies in `dgl.data` (#239)

* change Index.tolist -> Index.tonumpy; fix bug in traversal; remove dependencies in data

* fix import

* fix __all__ and some docstring
parent eafcb7e7
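
The central rename in this commit is `Index.tolist()` to `Index.tonumpy()`: the old name was misleading because it returned a numpy array, not a Python list. A minimal sketch of the renamed API as it appears in the `dgl.utils` hunk below (the sample values are illustrative):

```python
import numpy as np
from dgl import utils

# utils.toindex wraps list/numpy/tensor data into a dgl.utils.Index
idx = utils.toindex([0, 2, 5])

arr = idx.tonumpy()                 # renamed from tolist(); returns np.ndarray
assert isinstance(arr, np.ndarray)
assert arr.dtype == np.int64        # Index data is always a 1D int64 vector
assert arr.tolist() == [0, 2, 5]    # a real Python list now comes from numpy
```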
@@ -8,7 +8,7 @@ import torch.optim as optim
 from torch.utils.data import DataLoader
 import dgl
-import dgl.data as data
+from dgl.data.tree import SST
 from tree_lstm import TreeLSTM
@@ -25,22 +25,22 @@ def main(args):
     if cuda:
         th.cuda.set_device(args.gpu)
-    trainset = data.SST()
+    trainset = SST()
     train_loader = DataLoader(dataset=trainset,
                               batch_size=args.batch_size,
-                              collate_fn=data.SST.batcher(device),
+                              collate_fn=SST.batcher(device),
                               shuffle=True,
                               num_workers=0)
-    devset = data.SST(mode='dev')
+    devset = SST(mode='dev')
    dev_loader = DataLoader(dataset=devset,
                             batch_size=100,
-                            collate_fn=data.SST.batcher(device),
+                            collate_fn=SST.batcher(device),
                             shuffle=False,
                             num_workers=0)
-    testset = data.SST(mode='test')
+    testset = SST(mode='test')
     test_loader = DataLoader(dataset=testset,
-                             batch_size=100, collate_fn=data.SST.batcher(device), shuffle=False, num_workers=0)
+                             batch_size=100, collate_fn=SST.batcher(device), shuffle=False, num_workers=0)
     model = TreeLSTM(trainset.num_vocabs,
                      args.x_size,
...
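
For context, `SST.batcher(device)` in the hunk above is a collate-function factory: it returns the callable that `DataLoader` uses to batch a list of tree samples. A hedged usage sketch under the new import style (the device and batch size here are arbitrary choices):

```python
from torch.utils.data import DataLoader
from dgl.data.tree import SST

device = 'cpu'                       # arbitrary choice for the sketch
trainset = SST()                     # training split by default
loader = DataLoader(dataset=trainset,
                    batch_size=4,
                    collate_fn=SST.batcher(device),  # factory -> collate fn
                    shuffle=True,
                    num_workers=0)
batch = next(iter(loader))           # an SSTBatch (see tree.py's __all__ below)
```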
"""Dataset for stochastic block model."""
import math import math
import os import os
import pickle import pickle
......
@@ -6,8 +6,6 @@ Including:
 from __future__ import absolute_import
 from collections import namedtuple, OrderedDict
-from nltk.tree import Tree
-from nltk.corpus.reader import BracketParseCorpusReader
 import networkx as nx
 import numpy as np
@@ -16,6 +14,8 @@ import dgl
 import dgl.backend as F
 from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
+__all__ = ['SSTBatch', 'SST']
 _urls = {
     'sst' : 'dataset/sst.zip',
 }
@@ -63,6 +63,7 @@ class SST(object):
         print('Dataset creation finished. #Trees:', len(self.trees))
     def _load(self):
+        from nltk.corpus.reader import BracketParseCorpusReader
         # load vocab file
         self.vocab = OrderedDict()
         with open(self.vocab_file, encoding='utf-8') as vf:
...
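
Moving the `BracketParseCorpusReader` import into `_load()` makes nltk a lazy, optional dependency: importing `dgl.data.tree` no longer fails on machines without nltk, and nltk is only pulled in when a dataset object actually parses the corpus. A sketch of the resulting behavior:

```python
# Importing the module itself no longer requires nltk to be installed.
from dgl.data.tree import SST

# nltk is first imported here, inside SST._load(), when the corpus is parsed.
trainset = SST()
```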
@@ -13,6 +13,8 @@ except ImportError:
     pass
 requests = requests_failed_to_import
+__all__ = ['download', 'check_sha1', 'extract_archive', 'get_download_dir']
 def _get_dgl_url(file_url):
     """Get DGL online url for download."""
     dgl_repo_url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/'
...
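
Defining `__all__` pins down the public surface of `dgl.data.utils`: a star-import now exposes only the four listed helpers and no longer leaks internals such as the `requests` shim. A quick illustration (standard Python semantics, not DGL-specific):

```python
import dgl.data.utils as U

# Only the names listed in __all__ survive `from dgl.data.utils import *`.
assert set(U.__all__) == {'download', 'check_sha1',
                          'extract_archive', 'get_download_dir'}
```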
@@ -765,7 +765,7 @@ class FrameRef(MutableMapping):
         if isinstance(query, slice):
             query = range(query.start, query.stop)
         else:
-            query = query.tolist()
+            query = query.tonumpy()
         if isinstance(self._index_data, slice):
             self._index_data = range(self._index_data.start, self._index_data.stop)
@@ -861,51 +861,3 @@ def frame_like(other, num_rows):
     # now supports non-exist columns.
     newf._initializers = other._initializers
     return newf
-def merge_frames(frames, indices, max_index, reduce_func):
-    """Merge a list of frames.
-    The result frame contains `max_index` number of rows. For each frame in
-    the given list, its row is merged as follows:
-        merged[indices[i][row]] += frames[i][row]
-    Parameters
-    ----------
-    frames : iterator of dgl.frame.FrameRef
-        A list of frames to be merged.
-    indices : iterator of dgl.utils.Index
-        The indices of the frame rows.
-    reduce_func : str
-        The reduce function (only 'sum' is supported currently)
-    Returns
-    -------
-    merged : FrameRef
-        The merged frame.
-    """
-    # TODO(minjie)
-    assert False, 'Buggy code, disabled for now.'
-    assert reduce_func == 'sum'
-    assert len(frames) > 0
-    schemes = frames[0].schemes
-    # create an adj to merge
-    # row index is equal to the concatenation of all the indices.
-    row = sum([idx.tolist() for idx in indices], [])
-    col = list(range(len(row)))
-    n = max_index
-    m = len(row)
-    row = F.unsqueeze(F.tensor(row, dtype=F.int64), 0)
-    col = F.unsqueeze(F.tensor(col, dtype=F.int64), 0)
-    idx = F.cat([row, col], dim=0)
-    dat = F.ones((m,))
-    adjmat = F.sparse_tensor(idx, dat, [n, m])
-    ctx_adjmat = utils.CtxCachedObject(lambda ctx: F.to_context(adjmat, ctx))
-    merged = {}
-    for key in schemes:
-        # the rhs of the spmv is the concatenation of all the frame columns
-        feats = F.pack([fr[key] for fr in frames])
-        merged_feats = F.spmm(ctx_adjmat.get(F.get_context(feats)), feats)
-        merged[key] = merged_feats
-    merged = FrameRef(Frame(merged))
-    return merged
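
The removed `merge_frames` was flagged as buggy and becomes unused after the import cleanup below, but its docstring describes a well-defined scatter-add. For reference, a small numpy sketch of that semantics (all names and values here are illustrative, not DGL API):

```python
import numpy as np

frames = [np.array([[1.], [2.]]), np.array([[3.], [4.]])]   # two "frame" columns
indices = [np.array([0, 1]), np.array([1, 2])]              # target rows per frame
max_index = 3

merged = np.zeros((max_index, 1))
for frame, index in zip(frames, indices):
    # merged[indices[i][row]] += frames[i][row], accumulating duplicate rows
    np.add.at(merged, index, frame)
print(merged)   # [[1.], [5.], [4.]]
```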
@@ -8,7 +8,7 @@ from collections import defaultdict
 import dgl
 from .base import ALL, is_all, DGLError, dgl_warning
 from . import backend as F
-from .frame import FrameRef, Frame, merge_frames
+from .frame import FrameRef, Frame
 from .graph_index import GraphIndex, create_graph_index
 from .runtime import ir, scheduler, Runtime
 from . import utils
...
@@ -168,7 +168,7 @@ def _process_buckets(buckets):
     msg_ids = [utils.toindex(msg_id) for msg_id in msg_ids]
     # handle zero deg
-    degs = degs.tolist()
+    degs = degs.tonumpy()
     if degs[-1] == 0:
         degs = degs[:-1]
         zero_deg_nodes = dsts[-1]
...
@@ -44,7 +44,7 @@ def bfs_nodes_generator(graph, source, reversed=False):
     ret = _CAPI_DGLBFSNodes(ghandle, source, reversed)
     all_nodes = utils.toindex(ret(0)).tousertensor()
     # TODO(minjie): how to support directly creating python list
-    sections = utils.toindex(ret(1)).tousertensor().tolist()
+    sections = utils.toindex(ret(1)).tonumpy().tolist()
     node_frontiers = F.split(all_nodes, sections, dim=0)
     return node_frontiers
@@ -84,7 +84,7 @@ def bfs_edges_generator(graph, source, reversed=False):
     ret = _CAPI_DGLBFSEdges(ghandle, source, reversed)
     all_edges = utils.toindex(ret(0)).tousertensor()
     # TODO(minjie): how to support directly creating python list
-    sections = utils.toindex(ret(1)).tousertensor().tolist()
+    sections = utils.toindex(ret(1)).tonumpy().tolist()
     edge_frontiers = F.split(all_edges, sections, dim=0)
     return edge_frontiers
@@ -120,7 +120,7 @@ def topological_nodes_generator(graph, reversed=False):
     ret = _CAPI_DGLTopologicalNodes(ghandle, reversed)
     all_nodes = utils.toindex(ret(0)).tousertensor()
     # TODO(minjie): how to support directly creating python list
-    sections = utils.toindex(ret(1)).tousertensor().tolist()
+    sections = utils.toindex(ret(1)).tonumpy().tolist()
     return F.split(all_nodes, sections, dim=0)
 def dfs_edges_generator(graph, source, reversed=False):
@@ -165,7 +165,7 @@ def dfs_edges_generator(graph, source, reversed=False):
     ret = _CAPI_DGLDFSEdges(ghandle, source, reversed)
     all_edges = utils.toindex(ret(0)).tousertensor()
     # TODO(minjie): how to support directly creating python list
-    sections = utils.toindex(ret(1)).tousertensor().tolist()
+    sections = utils.toindex(ret(1)).tonumpy().tolist()
     return F.split(all_edges, sections, dim=0)
 def dfs_labeled_edges_generator(
@@ -244,11 +244,11 @@ def dfs_labeled_edges_generator(
     # TODO(minjie): how to support directly creating python list
     if return_labels:
         all_labels = utils.toindex(ret(1)).tousertensor()
-        sections = utils.toindex(ret(2)).tousertensor().tolist()
+        sections = utils.toindex(ret(2)).tonumpy().tolist()
         return (F.split(all_edges, sections, dim=0),
                 F.split(all_labels, sections, dim=0))
     else:
-        sections = utils.toindex(ret(1)).tousertensor().tolist()
+        sections = utils.toindex(ret(1)).tonumpy().tolist()
         return F.split(all_edges, sections, dim=0)
 _init_api("dgl.traversal")
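
The traversal fix is about the type of the section sizes: `F.split` (`torch.split` under the PyTorch backend) expects plain Python ints, and converting through `tonumpy().tolist()` guarantees that in a backend-agnostic way, whereas the old route through `tousertensor().tolist()` depended on the framework tensor's own `tolist`. A minimal sketch of the pattern the generators rely on (values are made up):

```python
import numpy as np
import torch as th

all_nodes = th.arange(6)                                  # concatenated frontiers
sections = np.array([1, 2, 3], dtype=np.int64).tolist()   # per-frontier sizes -> [1, 2, 3]
frontiers = th.split(all_nodes, sections, dim=0)
print([f.tolist() for f in frontiers])                    # [[0], [1, 2], [3, 4, 5]]
```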
@@ -15,24 +15,24 @@ class Index(object):
         self._initialize_data(data)
     def _initialize_data(self, data):
-        self._list_data = None   # a numpy type data or a slice
+        self._pydata = None   # a numpy type data or a slice
         self._user_tensor_data = dict()  # dictionary of user tensors
         self._dgl_tensor_data = None  # a dgl ndarray
         self._dispatch(data)
     def __iter__(self):
-        for i in self.tolist():
+        for i in self.tonumpy():
             yield int(i)
     def __len__(self):
-        if self._list_data is not None and isinstance(self._list_data, slice):
-            slc = self._list_data
+        if self._pydata is not None and isinstance(self._pydata, slice):
+            slc = self._pydata
             if slc.step is None:
                 return slc.stop - slc.start
             else:
                 return (slc.stop - slc.start) // slc.step
-        elif self._list_data is not None:
-            return len(self._list_data)
+        elif self._pydata is not None:
+            return len(self._pydata)
         elif len(self._user_tensor_data) > 0:
             data = next(iter(self._user_tensor_data.values()))
             return len(data)
@@ -40,7 +40,7 @@ class Index(object):
             return len(self._dgl_tensor_data)
     def __getitem__(self, i):
-        return int(self.tolist()[i])
+        return int(self.tonumpy()[i])
     def _dispatch(self, data):
         """Store data based on its type."""
@@ -59,35 +59,35 @@ class Index(object):
                 raise DGLError('Index data must be 1D int64 vector, but got: %s' % str(data))
             self._dgl_tensor_data = data
         elif isinstance(data, slice):
-            # save it in the _list_data temporarily; materialize it if `tolist` is called
-            self._list_data = data
+            # save it in the _pydata temporarily; materialize it if `tonumpy` is called
+            self._pydata = data
         else:
             try:
-                self._list_data = np.array([int(data)]).astype(np.int64)
+                self._pydata = np.array([int(data)]).astype(np.int64)
             except:
                 try:
                     data = np.array(data).astype(np.int64)
                     if data.ndim != 1:
                         raise DGLError('Index data must be 1D int64 vector,'
                                        ' but got: %s' % str(data))
-                    self._list_data = data
+                    self._pydata = data
                 except:
                     raise DGLError('Error index data: %s' % str(data))
-            self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._list_data)
+            self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self._pydata)
-    def tolist(self):
-        """Convert to a python-list compatible object."""
-        if self._list_data is None:
+    def tonumpy(self):
+        """Convert to a numpy ndarray."""
+        if self._pydata is None:
             if self._dgl_tensor_data is not None:
-                self._list_data = self._dgl_tensor_data.asnumpy()
+                self._pydata = self._dgl_tensor_data.asnumpy()
             else:
                 data = self.tousertensor()
-                self._list_data = F.zerocopy_to_numpy(data)
-        elif isinstance(self._list_data, slice):
+                self._pydata = F.zerocopy_to_numpy(data)
+        elif isinstance(self._pydata, slice):
             # convert it to numpy array
-            slc = self._list_data
-            self._list_data = np.arange(slc.start, slc.stop, slc.step).astype(np.int64)
-        return self._list_data
+            slc = self._pydata
+            self._pydata = np.arange(slc.start, slc.stop, slc.step).astype(np.int64)
+        return self._pydata
     def tousertensor(self, ctx=None):
         """Convert to user tensor (defined in `backend`)."""
@@ -100,7 +100,7 @@ class Index(object):
             self._user_tensor_data[F.cpu()] = F.zerocopy_from_dlpack(dl)
         else:
             # zero copy from numpy array
-            self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self.tolist())
+            self._user_tensor_data[F.cpu()] = F.zerocopy_from_numpy(self.tonumpy())
         if ctx not in self._user_tensor_data:
             # copy from cpu to another device
             data = next(iter(self._user_tensor_data.values()))
@@ -117,8 +117,8 @@ class Index(object):
         return self._dgl_tensor_data
     def is_slice(self, start, stop, step=None):
-        return (isinstance(self._list_data, slice)
-                and self._list_data == slice(start, stop, step))
+        return (isinstance(self._pydata, slice)
+                and self._pydata == slice(start, stop, step))
     def __getstate__(self):
         return self.tousertensor()
...
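
One subtlety preserved through the rename: a slice is stored lazily in `_pydata` and only materialized into an ndarray once `tonumpy()` runs. A hedged sketch (it assumes `toindex` accepts a slice, per the `_dispatch` branch above; the explicit step avoids the `step=None` corner of `np.arange`):

```python
from dgl import utils

idx = utils.toindex(slice(0, 4, 1))   # stored as a slice, not materialized yet
print(len(idx))                       # 4, computed from the slice bounds
print(idx.is_slice(0, 4, 1))          # True: still lazy
print(idx.tonumpy())                  # [0 1 2 3]; materialized via np.arange
print(idx.is_slice(0, 4, 1))          # False: _pydata is now an ndarray
```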
@@ -11,14 +11,14 @@ def test_edge_id():
     gi.add_nodes(4)
     gi.add_edge(0, 1)
-    eid = gi.edge_id(0, 1).tolist()
+    eid = gi.edge_id(0, 1).tonumpy()
     assert len(eid) == 1
     assert eid[0] == 0
     assert gi.is_multigraph()
     # multiedges
     gi.add_edge(0, 1)
-    eid = gi.edge_id(0, 1).tolist()
+    eid = gi.edge_id(0, 1).tonumpy()
     assert len(eid) == 2
     assert eid[0] == 0
     assert eid[1] == 1
@@ -60,7 +60,7 @@ def test_edge_id():
     gi.add_nodes(4)
     gi.add_edge(0, 1)
-    eid = gi.edge_id(0, 1).tolist()
+    eid = gi.edge_id(0, 1).tonumpy()
     assert len(eid) == 1
     assert eid[0] == 0
...
@@ -62,13 +62,13 @@ def check_basics(g, ig):
     for u in randv.asnumpy():
         for v in randv.asnumpy():
-            if len(g.edge_id(u, v).tolist()) == 1:
-                assert g.edge_id(u, v).tolist() == ig.edge_id(u, v).tolist()
+            if len(g.edge_id(u, v)) == 1:
+                assert g.edge_id(u, v).tonumpy() == ig.edge_id(u, v).tonumpy()
             assert g.has_edge_between(u, v) == ig.has_edge_between(u, v)
     randv = utils.toindex(randv)
-    ids = g.edge_ids(randv, randv)[2].tolist()
-    assert sum(ig.edge_ids(randv, randv)[2].tolist() == ids) == len(ids)
-    assert sum(g.has_edges_between(randv, randv).tolist() == ig.has_edges_between(randv, randv).tolist()) == len(randv)
+    ids = g.edge_ids(randv, randv)[2].tonumpy()
+    assert sum(ig.edge_ids(randv, randv)[2].tonumpy() == ids) == len(ids)
+    assert sum(g.has_edges_between(randv, randv).tonumpy() == ig.has_edges_between(randv, randv).tonumpy()) == len(randv)
 def test_basics():
...
@@ -208,7 +208,7 @@ def test_row3():
     assert f.is_contiguous()
     assert f.is_span_whole_column()
     assert f.num_rows == N
-    del f[th.tensor([2, 3])]
+    del f[toindex(th.tensor([2, 3]))]
     assert not f.is_contiguous()
     assert not f.is_span_whole_column()
     # delete is lazy: only reflect on the ref while the
...
@@ -49,7 +49,7 @@ def test_index():
     # from np data
     data = np.ones((10,), dtype=np.int64) * 10
     idx = toindex(data)
-    y1 = idx.tolist()
+    y1 = idx.tonumpy()
     y2 = idx.tousertensor().numpy()
     y3 = idx.todgltensor().asnumpy()
     assert np.allclose(ans, y1)
@@ -59,7 +59,7 @@ def test_index():
     # from list
     data = [10] * 10
     idx = toindex(data)
-    y1 = idx.tolist()
+    y1 = idx.tonumpy()
     y2 = idx.tousertensor().numpy()
     y3 = idx.todgltensor().asnumpy()
     assert np.allclose(ans, y1)
@@ -69,7 +69,7 @@ def test_index():
     # from torch
     data = th.ones((10,), dtype=th.int64) * 10
     idx = toindex(data)
-    y1 = idx.tolist()
+    y1 = idx.tonumpy()
     y2 = idx.tousertensor().numpy()
     y3 = idx.todgltensor().asnumpy()
     assert np.allclose(ans, y1)
@@ -79,7 +79,7 @@ def test_index():
     # from dgl.NDArray
     data = dgl.ndarray.array(np.ones((10,), dtype=np.int64) * 10)
     idx = toindex(data)
-    y1 = idx.tolist()
+    y1 = idx.tonumpy()
     y2 = idx.tousertensor().numpy()
     y3 = idx.todgltensor().asnumpy()
     assert np.allclose(ans, y1)
...
@@ -46,13 +46,13 @@ Tree LSTM DGL Tutorial
 #
 import dgl
-import dgl.data as data
+from dgl.data.tree import SST
 # Each sample in the dataset is a constituency tree. The leaf nodes
 # represent words. The word is a int value stored in the "x" field.
 # The non-leaf nodes has a special word PAD_WORD. The sentiment
 # label is stored in the "y" feature field.
-trainset = data.SST(mode='tiny')  # the "tiny" set has only 5 trees
+trainset = SST(mode='tiny')  # the "tiny" set has only 5 trees
 tiny_sst = trainset.trees
 num_vocabs = trainset.num_vocabs
 num_classes = trainset.num_classes
@@ -337,7 +337,7 @@ optimizer = th.optim.Adagrad(model.parameters(),
 train_loader = DataLoader(dataset=tiny_sst,
                           batch_size=5,
-                          collate_fn=data.SST.batcher(device),
+                          collate_fn=SST.batcher(device),
                           shuffle=False,
                           num_workers=0)
...
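
The tutorial comments above describe the SST layout: each sample is a constituency tree whose leaves carry word ids in the "x" field, with PAD_WORD on internal nodes and per-node sentiment labels in the "y" field. A quick inspection sketch using only attributes the tutorial itself touches:

```python
from dgl.data.tree import SST

trainset = SST(mode='tiny')          # the "tiny" split holds just 5 trees
tiny_sst = trainset.trees            # list of per-sentence tree graphs
print(len(tiny_sst))                 # 5
print(trainset.num_vocabs, trainset.num_classes)
```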