Unverified Commit 2b77ad41 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix GCMC (#2067)



* Fix

* Update data.py

* Update README.md

* Fix
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
parent 83654fdc
......@@ -293,7 +293,8 @@ class MovieLens(object):
(ones, rating_pairs),
shape=(self.num_user, self.num_movie), dtype=np.float32)
g = dgl.bipartite_from_scipy(user_movie_ratings_coo, utype='_U', etype='_E', vtype='_V')
return dgl.heterograph({('user', 'rate', 'movie'): g.edges()})
return dgl.heterograph({('user', 'rate', 'movie'): g.edges()},
num_nodes_dict={'user': self.num_user, 'movie': self.num_movie})
@property
def num_links(self):
......
......@@ -12,6 +12,8 @@ Credit: Jiani Zhang ([@jennyzhang0215](https://github.com/jennyzhang0215))
* PyTorch 1.2+
* pandas
* torchtext 0.4+ (if using user and item contents as node features)
* spacy (if using user and item contents as node features)
- You will also need to run `python -m spacy download en_core_web_sm`
## Data
......
......@@ -302,7 +302,8 @@ class MovieLens(object):
(ones, rating_pairs),
shape=(self.num_user, self.num_movie), dtype=np.float32)
g = dgl.bipartite_from_scipy(user_movie_ratings_coo, utype='_U', etype='_E', vtype='_V')
return dgl.heterograph({('user', 'rate', 'movie'): g.edges()})
return dgl.heterograph({('user', 'rate', 'movie'): g.edges()},
num_nodes_dict={'user': self.num_user, 'movie': self.num_movie})
@property
def num_links(self):
......
......@@ -121,6 +121,7 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
frontier = frontier.to(dev_id)
head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id)
pair_graph = pair_graph.to(dev_id)
with th.no_grad():
pred_ratings = net(pair_graph, frontier,
head_feat, tail_feat, possible_rating_values)
......@@ -335,7 +336,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
dataset=dataset,
dataloader=valid_dataloader,
segment='valid')
logging_str += ',\tVal RMSE={:.4f}'.format(valid_rmse)
logging_str = 'Val RMSE={:.4f}'.format(valid_rmse)
if valid_rmse < best_valid_rmse:
best_valid_rmse = valid_rmse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment