Unverified Commit 28deee4d authored by Quan (Andy) Gan, committed by GitHub

[Model] fix PinSAGE example for 0.5 release (#2072)

parent dba8c825
@@ -106,7 +106,7 @@ class PandasGraphBuilder(object):
         dsttype = self.entity_pk_to_name[destination_key]
         etype = (srctype, name, dsttype)
         self.relation_name_to_etype[name] = etype
-        self.edges_per_relation[etype] = (src.cat.codes.values, dst.cat.codes.values)
+        self.edges_per_relation[etype] = (src.cat.codes.values.astype('int64'), dst.cat.codes.values.astype('int64'))
         self.relation_tables[name] = relation_table
         self.relation_src_key[name] = source_key
         self.relation_dst_key[name] = destination_key
......
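The `astype('int64')` cast is needed because pandas stores categorical codes in the narrowest integer dtype that fits (often int8), while DGL expects 64-bit node IDs. A minimal sketch of the dtype mismatch, with made-up values:

```python
# pandas picks the smallest integer dtype for categorical codes; DGL wants int64 IDs.
import pandas as pd

src = pd.Series(['u1', 'u2', 'u1']).astype('category')
print(src.cat.codes.values.dtype)                  # int8 with this few categories
print(src.cat.codes.values.astype('int64').dtype)  # int64, safe to hand to DGL
```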
@@ -8,7 +8,7 @@ import dask.dataframe as dd
 # This is the train-test split method most of the recommender system papers running on MovieLens
 # take. It essentially follows the intuition of "training on the past and predicting the future".
 # One can also change the threshold to make the validation and test sets take larger proportions.
-def train_test_split_by_time(df, timestamp, item):
+def train_test_split_by_time(df, timestamp, user):
     df['train_mask'] = np.ones((len(df),), dtype=np.bool)
     df['val_mask'] = np.zeros((len(df),), dtype=np.bool)
     df['test_mask'] = np.zeros((len(df),), dtype=np.bool)
@@ -22,8 +22,8 @@ def train_test_split_by_time(df, timestamp, item):
         df.iloc[-2, -3] = False
         df.iloc[-2, -2] = True
         return df
-    df = df.groupby(item).apply(train_test_split).compute(scheduler='processes').sort_index()
-    print(df[df[item] == df[item].unique()[0]].sort_values(timestamp))
+    df = df.groupby(user, group_keys=False).apply(train_test_split).compute(scheduler='processes').sort_index()
+    print(df[df[user] == df[user].unique()[0]].sort_values(timestamp))
     return df['train_mask'].to_numpy().nonzero()[0], \
            df['val_mask'].to_numpy().nonzero()[0], \
            df['test_mask'].to_numpy().nonzero()[0]
......
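The fix above switches the grouping key from item to user, so the hold-out is each user's latest interactions rather than each item's. A rough pandas-only equivalent of the split, without the dask machinery (column names `timestamp` and `user_id` are illustrative):

```python
import pandas as pd

def split_by_time(df, timestamp='timestamp', user='user_id'):
    # Rank each user's interactions from newest (0) to oldest, then map the
    # ranks back to the original row order (assumes a unique index).
    rank = (df.sort_values(timestamp)
              .groupby(user).cumcount(ascending=False)
              .reindex(df.index))
    size = df.groupby(user)[user].transform('size')
    test_mask = (rank == 0) & (size > 1)   # each user's latest interaction
    val_mask = (rank == 1) & (size > 2)    # each user's second-to-last one
    train_mask = ~(test_mask | val_mask)   # single-interaction users stay in train
    return (train_mask.to_numpy().nonzero()[0],
            val_mask.to_numpy().nonzero()[0],
            test_mask.to_numpy().nonzero()[0])
```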
@@ -116,10 +116,11 @@ if __name__ == '__main__':
     # Train-validation-test split
     # This is a little bit tricky as we want to select the last interaction for test, and the
     # second-to-last interaction for validation.
-    train_indices, val_indices, test_indices = train_test_split_by_time(ratings, 'timestamp', 'movie_id')
+    train_indices, val_indices, test_indices = train_test_split_by_time(ratings, 'timestamp', 'user_id')
     # Build the graph with training interactions only.
     train_g = build_train_graph(g, train_indices, 'user', 'movie', 'watched', 'watched-by')
+    assert train_g.out_degrees(etype='watched').min() > 0
     # Build the user-item sparse matrix for validation and test set.
     val_matrix, test_matrix = build_val_test_matrix(g, val_indices, test_indices, 'user', 'movie', 'watched')
......
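The added assert guards against users who lost all of their interactions to the validation/test split: a user with zero out-degree in the training graph would give the neighbor sampler nothing to walk on. A toy check in the same spirit, on a hand-built graph (the tensors are made up):

```python
import dgl
import torch

# Two users, two movies; every user keeps at least one training edge.
g = dgl.heterograph({
    ('user', 'watched', 'movie'): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1])),
    ('movie', 'watched-by', 'user'): (torch.tensor([0, 1, 1]), torch.tensor([0, 0, 1])),
})
assert g.out_degrees(etype='watched').min() > 0  # no user left without training edges
```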
@@ -53,8 +53,9 @@ if __name__ == '__main__':
     g.edges['listened-by'].data['created_at'] = torch.LongTensor(events['created_at'].values)

     n_edges = g.number_of_edges('listened')
-    train_indices, val_indices, test_indices = train_test_split_by_time(events, 'created_at', 'track_id')
+    train_indices, val_indices, test_indices = train_test_split_by_time(events, 'created_at', 'user_id')
     train_g = build_train_graph(g, train_indices, 'user', 'track', 'listened', 'listened-by')
+    assert train_g.out_degrees(etype='listened').min() > 0
     val_matrix, test_matrix = build_val_test_matrix(
         g, val_indices, test_indices, 'user', 'track', 'listened')
......
@@ -68,12 +68,10 @@ class NeighborSampler(object):
         # connections only.
         pos_graph = dgl.graph(
             (heads, tails),
-            num_nodes=self.g.number_of_nodes(self.item_type),
-            ntype=self.item_type)
+            num_nodes=self.g.number_of_nodes(self.item_type))
         neg_graph = dgl.graph(
             (heads, neg_tails),
-            num_nodes=self.g.number_of_nodes(self.item_type),
-            ntype=self.item_type)
+            num_nodes=self.g.number_of_nodes(self.item_type))
         pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
         seeds = pos_graph.ndata[dgl.NID]
......
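In DGL 0.5, `dgl.graph` builds a homogeneous graph and no longer takes the `ntype` keyword, hence its removal here; the node type is implicit since heads and tails are all items. A minimal sketch of the pattern with made-up ID tensors, assuming DGL 0.5+:

```python
import dgl
import torch

num_items = 10
heads = torch.tensor([0, 2, 4])
tails = torch.tensor([2, 4, 6])       # positive co-interaction pairs
neg_tails = torch.tensor([1, 3, 5])   # sampled negatives

pos_graph = dgl.graph((heads, tails), num_nodes=num_items)
neg_graph = dgl.graph((heads, neg_tails), num_nodes=num_items)
# Drop items that appear in neither graph; both graphs get the same relabeling.
pos_graph, neg_graph = dgl.compact_graphs([pos_graph, neg_graph])
seeds = pos_graph.ndata[dgl.NID]  # original item IDs of the surviving nodes
```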