"tests/compute/test_traversal.py" did not exist on "66676a548dd5ef77c8dcafe5218c04e572a4f2fb"
Unverified Commit cd6d1138 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[BugFix] fix dimension unmatch issue and legacy issue of torchtext (#3539)

parent 4bf70f09
......@@ -171,7 +171,7 @@ class ItemToItemScorer(nn.Module):
super().__init__()
n_nodes = full_graph.number_of_nodes(ntype)
self.bias = nn.Parameter(torch.zeros(n_nodes))
self.bias = nn.Parameter(torch.zeros(n_nodes, 1))
def _add_bias(self, edges):
bias_src = self.bias[edges.src[dgl.NID]]
......
......@@ -52,13 +52,13 @@ def train(dataset, args):
fields = {}
examples = []
for key, texts in item_texts.items():
fields[key] = torchtext.data.Field(include_lengths=True, lower=True, batch_first=True)
fields[key] = torchtext.legacy.data.Field(include_lengths=True, lower=True, batch_first=True)
for i in range(g.number_of_nodes(item_ntype)):
example = torchtext.data.Example.fromlist(
example = torchtext.legacy.data.Example.fromlist(
[item_texts[key][i] for key in item_texts.keys()],
[(key, fields[key]) for key in item_texts.keys()])
examples.append(example)
textset = torchtext.data.Dataset(examples, fields)
textset = torchtext.legacy.data.Dataset(examples, fields)
for key, field in fields.items():
field.build_vocab(getattr(textset, key))
#field.build_vocab(getattr(textset, key), vectors='fasttext.simple.300d')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment