"docs/source/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "279e2e8fbab2f8a09449a67079c071f9ca6cfaac"
Unverified Commit 2b77ad41 authored by Mufei Li's avatar Mufei Li Committed by GitHub
Browse files

[Example] Fix GCMC (#2067)



* Fix

* Update data.py

* Update README.md

* Fix
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-1-5.us-west-2.compute.internal>
parent 83654fdc
...@@ -293,7 +293,8 @@ class MovieLens(object): ...@@ -293,7 +293,8 @@ class MovieLens(object):
(ones, rating_pairs), (ones, rating_pairs),
shape=(self.num_user, self.num_movie), dtype=np.float32) shape=(self.num_user, self.num_movie), dtype=np.float32)
g = dgl.bipartite_from_scipy(user_movie_ratings_coo, utype='_U', etype='_E', vtype='_V') g = dgl.bipartite_from_scipy(user_movie_ratings_coo, utype='_U', etype='_E', vtype='_V')
return dgl.heterograph({('user', 'rate', 'movie'): g.edges()}) return dgl.heterograph({('user', 'rate', 'movie'): g.edges()},
num_nodes_dict={'user': self.num_user, 'movie': self.num_movie})
@property @property
def num_links(self): def num_links(self):
......
...@@ -12,6 +12,8 @@ Credit: Jiani Zhang ([@jennyzhang0215](https://github.com/jennyzhang0215)) ...@@ -12,6 +12,8 @@ Credit: Jiani Zhang ([@jennyzhang0215](https://github.com/jennyzhang0215))
* PyTorch 1.2+ * PyTorch 1.2+
* pandas * pandas
* torchtext 0.4+ (if using user and item contents as node features) * torchtext 0.4+ (if using user and item contents as node features)
* spacy (if using user and item contents as node features)
- You will also need to run `python -m spacy download en_core_web_sm`
## Data ## Data
......
...@@ -302,7 +302,8 @@ class MovieLens(object): ...@@ -302,7 +302,8 @@ class MovieLens(object):
(ones, rating_pairs), (ones, rating_pairs),
shape=(self.num_user, self.num_movie), dtype=np.float32) shape=(self.num_user, self.num_movie), dtype=np.float32)
g = dgl.bipartite_from_scipy(user_movie_ratings_coo, utype='_U', etype='_E', vtype='_V') g = dgl.bipartite_from_scipy(user_movie_ratings_coo, utype='_U', etype='_E', vtype='_V')
return dgl.heterograph({('user', 'rate', 'movie'): g.edges()}) return dgl.heterograph({('user', 'rate', 'movie'): g.edges()},
num_nodes_dict={'user': self.num_user, 'movie': self.num_movie})
@property @property
def num_links(self): def num_links(self):
......
...@@ -121,6 +121,7 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'): ...@@ -121,6 +121,7 @@ def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
frontier = frontier.to(dev_id) frontier = frontier.to(dev_id)
head_feat = head_feat.to(dev_id) head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id) tail_feat = tail_feat.to(dev_id)
pair_graph = pair_graph.to(dev_id)
with th.no_grad(): with th.no_grad():
pred_ratings = net(pair_graph, frontier, pred_ratings = net(pair_graph, frontier,
head_feat, tail_feat, possible_rating_values) head_feat, tail_feat, possible_rating_values)
...@@ -335,7 +336,7 @@ def run(proc_id, n_gpus, args, devices, dataset): ...@@ -335,7 +336,7 @@ def run(proc_id, n_gpus, args, devices, dataset):
dataset=dataset, dataset=dataset,
dataloader=valid_dataloader, dataloader=valid_dataloader,
segment='valid') segment='valid')
logging_str += ',\tVal RMSE={:.4f}'.format(valid_rmse) logging_str = 'Val RMSE={:.4f}'.format(valid_rmse)
if valid_rmse < best_valid_rmse: if valid_rmse < best_valid_rmse:
best_valid_rmse = valid_rmse best_valid_rmse = valid_rmse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment