Commit 23d09057 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4642)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
import dgl
import numpy as np
import torch as th


class Sampler:
    def __init__(
        self, graph, walk_length, num_walks, window_size, num_negative
    ):
        self.graph = graph
        self.walk_length = walk_length
        self.num_walks = num_walks

@@ -19,8 +17,8 @@ class Sampler:
    def sample(self, batch, sku_info):
        """
        Given a batch of target nodes, sample positive
        pairs and negative pairs from the graph
        """
        batch = np.repeat(batch, self.num_walks)
@@ -46,17 +44,14 @@ class Sampler:
    def generate_pos_pairs(self, nodes):
        """
        For seq [1, 2, 3, 4] and node NO.2,
        the window_size=1 will generate:
        (1, 2) and (2, 3)
        """
        # random walk
        traces, types = dgl.sampling.random_walk(
            g=self.graph, nodes=nodes, length=self.walk_length, prob="weight"
        )
        traces = traces.tolist()
        self.filter_padding(traces)
@@ -68,32 +63,32 @@ class Sampler:
                left = max(0, i - self.window_size)
                right = min(len(trace), i + self.window_size + 1)
                pairs.extend([[center, x, 1] for x in trace[left:i]])
                pairs.extend([[center, x, 1] for x in trace[i + 1 : right]])
        return pairs

    def compute_node_sample_weight(self):
        """
        Using node degree as sample weight
        """
        return self.graph.in_degrees().float()

    def generate_neg_pairs(self, pos_pairs):
        """
        Sample based on node freq in traces, frequently shown
        nodes will have larger chance to be sampled as
        negative node.
        """
        # sample `self.num_negative` neg dst node
        # for each pos node pair's src node.
        negs = th.multinomial(
            self.node_weights,
            len(pos_pairs) * self.num_negative,
            replacement=True,
        ).tolist()

        tar = np.repeat([pair[0] for pair in pos_pairs], self.num_negative)
        assert len(tar) == len(negs)

        neg_pairs = [[x, y, 0] for x, y in zip(tar, negs)]
        return neg_pairs
import argparse
import random
from datetime import datetime

import networkx as nx
import numpy as np
import torch as th

import dgl


def init_args():
    # TODO: change args
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--session_interval_sec", type=int, default=1800)
    argparser.add_argument(
        "--action_data", type=str, default="data/action_head.csv"
    )
    argparser.add_argument(
        "--item_info_data", type=str, default="data/jdata_product.csv"
    )
    argparser.add_argument("--walk_length", type=int, default=10)
    argparser.add_argument("--num_walks", type=int, default=5)
    argparser.add_argument("--batch_size", type=int, default=64)
    argparser.add_argument("--dim", type=int, default=16)
    argparser.add_argument("--epochs", type=int, default=30)
    argparser.add_argument("--window_size", type=int, default=2)
    argparser.add_argument("--num_negative", type=int, default=5)
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument("--log_every", type=int, default=100)

    return argparser.parse_args()


def construct_graph(datapath, session_interval_gap_sec, valid_sku_raw_ids):
    user_clicks, sku_encoder, sku_decoder = parse_actions(
        datapath, valid_sku_raw_ids
    )

    # {src,dst: weight}
    graph = {}
    for user_id, action_list in user_clicks.items():
        # sort by action time
        _action_list = sorted(action_list, key=lambda x: x[1])
        last_action_time = datetime.strptime(
            _action_list[0][1], "%Y-%m-%d %H:%M:%S"
        )
        session = [_action_list[0][0]]
        # cut sessions and add to graph
        for sku_id, action_time in _action_list[1:]:
@@ -52,7 +61,7 @@ def construct_graph(datapath, session_interval_gap_sec, valid_sku_raw_ids):
                session = [sku_id]
        # add last session
        add_session(session, graph)

    g = convert_to_dgl_graph(graph)
    return g, sku_encoder, sku_decoder

@@ -66,20 +75,20 @@ def convert_to_dgl_graph(graph):
        src, dst = int(nodes[0]), int(nodes[1])
        g.add_edge(src, dst, weight=float(weight))

    return dgl.from_networkx(g, edge_attrs=["weight"])


def add_session(session, graph):
    """
    For session like:
        [sku1, sku2, sku3]
    add 1 weight to each of the following edges:
        sku1 -> sku2
        sku2 -> sku3
    If session length < 2, no nodes/edges will be added
    """
    for i in range(len(session) - 1):
        edge = str(session[i]) + "," + str(session[i + 1])
        try:
            graph[edge] += 1
        except KeyError:
@@ -104,17 +113,16 @@ def parse_actions(datapath, valid_sku_raw_ids):
            if sku_raw_id in valid_sku_raw_ids:
                action_time = fields[2]

                # encode sku_id
                sku_id = encode_id(
                    sku_encoder, sku_decoder, sku_raw_id, sku_id
                )

                # add to user clicks
                try:
                    user_clicks[user_id].append((sku_id, action_time))
                except KeyError:
                    user_clicks[user_id] = [(sku_id, action_time)]

    return user_clicks, sku_encoder, sku_decoder

@@ -136,7 +144,7 @@ def get_valid_sku_set(datapath):
            line.replace("\n", "")
            sku_raw_id = line.split(",")[0]
            sku_ids.add(sku_raw_id)

    return sku_ids
@@ -159,27 +167,27 @@ def encode_sku_fields(datapath, sku_encoder, sku_decoder):
            if sku_raw_id in sku_encoder:
                sku_id = sku_encoder[sku_raw_id]
                brand_id = encode_id(
                    sku_info_encoder["brand"],
                    sku_info_decoder["brand"],
                    brand_raw_id,
                    brand_id,
                )
                shop_id = encode_id(
                    sku_info_encoder["shop"],
                    sku_info_decoder["shop"],
                    shop_raw_id,
                    shop_id,
                )
                cate_id = encode_id(
                    sku_info_encoder["cate"],
                    sku_info_decoder["cate"],
                    cate_raw_id,
                    cate_id,
                )

                sku_info[sku_id] = [sku_id, brand_id, shop_id, cate_id]

@@ -195,20 +203,22 @@ class TestEdge:
def split_train_test_graph(graph):
    """
    For test true edges, 1/3 of the edges are randomly chosen
    and removed as ground truth in the test set,
    the remaining graph is taken as the training set.
    """
    test_edges = []
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(1)
    sampled_edge_ids = random.sample(
        range(graph.num_edges()), int(graph.num_edges() / 3)
    )
    for edge_id in sampled_edge_ids:
        src, dst = graph.find_edges(edge_id)
        test_edges.append(TestEdge(src, dst, 1))
        src, dst = neg_sampler(graph, th.tensor([edge_id]))
        test_edges.append(TestEdge(src, dst, 0))

    graph.remove_edges(sampled_edge_ids)
    test_graph = test_edges
......
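
To make add_session's contract concrete, a tiny check (not from the commit) based on its docstring: each consecutive pair in a session adds weight 1 to that directed edge, keyed as "src,dst" (the elided `except KeyError` branch is assumed to initialize the weight to 1):

# --- illustrative check, not from the diff ---
graph = {}
add_session(["1", "2", "3"], graph)
add_session(["2", "3"], graph)
assert graph == {"1,2": 1, "2,3": 2}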
import os

import numpy
import pandas
import torch

import dgl

@@ -14,35 +16,52 @@ def process_raw_data(raw_dir, processed_dir):
    github.com/IBM/EvolveGCN/blob/master/elliptic_construction.md

    The main purpose is to convert the original idx to contiguous idx starting at 0.
    """
    oid_nid_path = os.path.join(processed_dir, "oid_nid.npy")
    id_label_path = os.path.join(processed_dir, "id_label.npy")
    id_time_features_path = os.path.join(processed_dir, "id_time_features.npy")
    src_dst_time_path = os.path.join(processed_dir, "src_dst_time.npy")
    if (
        os.path.exists(oid_nid_path)
        and os.path.exists(id_label_path)
        and os.path.exists(id_time_features_path)
        and os.path.exists(src_dst_time_path)
    ):
        print(
            "The preprocessed data already exists, skip the preprocess stage!"
        )
        return

    print("starting process raw data in {}".format(raw_dir))

    id_label = pandas.read_csv(
        os.path.join(raw_dir, "elliptic_txs_classes.csv")
    )
    src_dst = pandas.read_csv(
        os.path.join(raw_dir, "elliptic_txs_edgelist.csv")
    )
    # elliptic_txs_features.csv has no header; its rows are in the same order as elliptic_txs_classes.csv
    id_time_features = pandas.read_csv(
        os.path.join(raw_dir, "elliptic_txs_features.csv"), header=None
    )

    # get oldId_newId
    oid_nid = id_label.loc[:, ["txId"]]
    oid_nid = oid_nid.rename(columns={"txId": "originalId"})
    oid_nid.insert(1, "newId", range(len(oid_nid)))

    # map classes unknown,1,2 to -1,1,0 and construct id_label. type 1 means illicit.
    id_label = pandas.concat(
        [
            oid_nid["newId"],
            id_label["class"].map({"unknown": -1.0, "1": 1.0, "2": 0.0}),
        ],
        axis=1,
    )

    # replace originalId with newId.
    # Attention: the timestamps in the features start at 1.
    id_time_features[0] = oid_nid["newId"]

    # construct originalId2newId dict
    oid_nid_dict = oid_nid.set_index(["originalId"])["newId"].to_dict()
    # construct newId2timestamp dict
    nid_time_dict = id_time_features.set_index([0])[1].to_dict()
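
# Side note (not in the commit): the reindexing pattern above boils down to
# a dict built from a two-column frame, e.g.:
#   frame = pandas.DataFrame({"originalId": [9352, 127, 881]})
#   frame["newId"] = range(len(frame))
#   frame.set_index("originalId")["newId"].to_dict()  # {9352: 0, 127: 1, 881: 2}
# The IDs here are made up for illustration.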
@@ -53,9 +72,9 @@ def process_raw_data(raw_dir, processed_dir):
    # In the EvolveGCN example, the edge timestamp is not used.
    #
    # Note: in the dataset, src and dst nodes have the same timestamp, so it's easy to set the edge's timestamp.
    new_src = src_dst["txId1"].map(oid_nid_dict).rename("newSrc")
    new_dst = src_dst["txId2"].map(oid_nid_dict).rename("newDst")
    edge_time = new_src.map(nid_time_dict).rename("timestamp")
    src_dst_time = pandas.concat([new_src, new_dst, edge_time], axis=1)

    # save oid_nid, id_label, id_time_features, src_dst_time to disk. we can convert them to numpy.
@@ -69,11 +88,17 @@ def process_raw_data(raw_dir, processed_dir):
    numpy.save(id_label_path, id_label)
    numpy.save(id_time_features_path, id_time_features)
    numpy.save(src_dst_time_path, src_dst_time)
    print(
        "Process Elliptic raw data done, data has saved into {}".format(
            processed_dir
        )
    )


class EllipticDataset:
    def __init__(
        self, raw_dir, processed_dir, self_loop=True, reverse_edge=True
    ):
        self.raw_dir = raw_dir
        self.processd_dir = processed_dir
        self.self_loop = self_loop
@@ -81,36 +106,63 @@ class EllipticDataset:
    def process(self):
        process_raw_data(self.raw_dir, self.processd_dir)
        id_time_features = torch.Tensor(
            numpy.load(os.path.join(self.processd_dir, "id_time_features.npy"))
        )
        id_label = torch.IntTensor(
            numpy.load(os.path.join(self.processd_dir, "id_label.npy"))
        )
        src_dst_time = torch.IntTensor(
            numpy.load(os.path.join(self.processd_dir, "src_dst_time.npy"))
        )

        src = src_dst_time[:, 0]
        dst = src_dst_time[:, 1]
        # id_label[:, 0] is used to add self loops
        if self.self_loop:
            if self.reverse_edge:
                g = dgl.graph(
                    data=(
                        torch.cat((src, dst, id_label[:, 0])),
                        torch.cat((dst, src, id_label[:, 0])),
                    ),
                    num_nodes=id_label.shape[0],
                )
                g.edata["timestamp"] = torch.cat(
                    (
                        src_dst_time[:, 2],
                        src_dst_time[:, 2],
                        id_time_features[:, 1].int(),
                    )
                )
            else:
                g = dgl.graph(
                    data=(
                        torch.cat((src, id_label[:, 0])),
                        torch.cat((dst, id_label[:, 0])),
                    ),
                    num_nodes=id_label.shape[0],
                )
                g.edata["timestamp"] = torch.cat(
                    (src_dst_time[:, 2], id_time_features[:, 1].int())
                )
        else:
            if self.reverse_edge:
                g = dgl.graph(
                    data=(torch.cat((src, dst)), torch.cat((dst, src))),
                    num_nodes=id_label.shape[0],
                )
                g.edata["timestamp"] = torch.cat(
                    (src_dst_time[:, 2], src_dst_time[:, 2])
                )
            else:
                g = dgl.graph(data=(src, dst), num_nodes=id_label.shape[0])
                g.edata["timestamp"] = src_dst_time[:, 2]

        time_features = id_time_features[:, 1:]
        label = id_label[:, 1]
        g.ndata["label"] = label
        g.ndata["feat"] = time_features

        # used to construct time-based sub-graphs.
        node_mask_by_time = []
......
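
A toy illustration (not from the commit) of the graph-construction branch above with both self loops and reverse edges: an edge list [(0, 1)] over 3 nodes yields one forward edge, one reverse edge, and one self loop per node.

# --- toy illustration, not from the diff ---
import torch
import dgl

src, dst = torch.tensor([0]), torch.tensor([1])
nids = torch.tensor([0, 1, 2])  # stands in for id_label[:, 0]
g = dgl.graph(
    data=(torch.cat((src, dst, nids)), torch.cat((dst, src, nids))),
    num_nodes=3,
)
print(g.num_edges())  # 5 = 1 forward + 1 reverse + 3 self loops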
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameter

from dgl.nn.pytorch import GraphConv


class MatGRUCell(torch.nn.Module):
    """
@@ -13,17 +14,11 @@ class MatGRUCell(torch.nn.Module):
    def __init__(self, in_feats, out_feats):
        super().__init__()
        self.update = MatGRUGate(in_feats, out_feats, torch.nn.Sigmoid())

        self.reset = MatGRUGate(in_feats, out_feats, torch.nn.Sigmoid())

        self.htilda = MatGRUGate(in_feats, out_feats, torch.nn.Tanh())

    def forward(self, prev_Q, z_topk=None):
        if z_topk is None:
@@ -60,9 +55,9 @@ class MatGRUGate(torch.nn.Module):
        init.zeros_(self.bias)

    def forward(self, x, hidden):
        out = self.activation(
            self.W.matmul(x) + self.U.matmul(hidden) + self.bias
        )

        return out
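
In equation form, the gate computed by this forward pass is out = φ(W·x + U·h + B), where φ is the gate's activation and W, U, B are its parameters. MatGRUCell applies three such gates — update and reset with sigmoid, htilda with tanh — to evolve a whole weight matrix rather than a hidden vector, hence the "Mat" prefix; the exact combination lives in the truncated forward of MatGRUCell, presumably the standard GRU update H' = (1 − Z) ∘ H + Z ∘ H̃.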
@@ -85,15 +80,26 @@ class TopK(torch.nn.Module):
        init.xavier_uniform_(self.scorer)

    def forward(self, node_embs):
        scores = node_embs.matmul(self.scorer) / self.scorer.norm().clamp(
            min=1e-6
        )
        vals, topk_indices = scores.view(-1).topk(self.k)
        out = node_embs[topk_indices] * torch.tanh(
            scores[topk_indices].view(-1, 1)
        )

        # we need to transpose the output
        return out.t()


class EvolveGCNH(nn.Module):
    def __init__(
        self,
        in_feats=166,
        n_hidden=76,
        num_layers=2,
        n_classes=2,
        classifier_hidden=510,
    ):
        # default parameters follow the official config
        super(EvolveGCNH, self).__init__()
        self.num_layers = num_layers
@@ -104,20 +110,44 @@ class EvolveGCNH(nn.Module):
        self.pooling_layers.append(TopK(in_feats, n_hidden))
        # similar to EvolveGCNO
        self.recurrent_layers.append(
            MatGRUCell(in_feats=in_feats, out_feats=n_hidden)
        )
        self.gcn_weights_list.append(
            Parameter(torch.Tensor(in_feats, n_hidden))
        )
        self.gnn_convs.append(
            GraphConv(
                in_feats=in_feats,
                out_feats=n_hidden,
                bias=False,
                activation=nn.RReLU(),
                weight=False,
            )
        )
        for _ in range(num_layers - 1):
            self.pooling_layers.append(TopK(n_hidden, n_hidden))
            self.recurrent_layers.append(
                MatGRUCell(in_feats=n_hidden, out_feats=n_hidden)
            )
            self.gcn_weights_list.append(
                Parameter(torch.Tensor(n_hidden, n_hidden))
            )
            self.gnn_convs.append(
                GraphConv(
                    in_feats=n_hidden,
                    out_feats=n_hidden,
                    bias=False,
                    activation=nn.RReLU(),
                    weight=False,
                )
            )

        self.mlp = nn.Sequential(
            nn.Linear(n_hidden, classifier_hidden),
            nn.ReLU(),
            nn.Linear(classifier_hidden, n_classes),
        )
        self.reset_parameters()

    def reset_parameters(self):
@@ -127,18 +157,27 @@ class EvolveGCNH(nn.Module):
    def forward(self, g_list):
        feature_list = []
        for g in g_list:
            feature_list.append(g.ndata["feat"])
        for i in range(self.num_layers):
            W = self.gcn_weights_list[i]
            for j, g in enumerate(g_list):
                X_tilde = self.pooling_layers[i](feature_list[j])
                W = self.recurrent_layers[i](W, X_tilde)
                feature_list[j] = self.gnn_convs[i](
                    g, feature_list[j], weight=W
                )
        return self.mlp(feature_list[-1])


class EvolveGCNO(nn.Module):
    def __init__(
        self,
        in_feats=166,
        n_hidden=256,
        num_layers=2,
        n_classes=2,
        classifier_hidden=307,
    ):
        # default parameters follow the official config
        super(EvolveGCNO, self).__init__()
        self.num_layers = num_layers
@@ -154,19 +193,43 @@ class EvolveGCNO(nn.Module):
        # but the performance is worse than using torch.nn.GRU.
        # PPS: I think torch.nn.GRU can't match the manually implemented GRU cell in the official code,
        # we follow the official code here.
        self.recurrent_layers.append(
            MatGRUCell(in_feats=in_feats, out_feats=n_hidden)
        )
        self.gcn_weights_list.append(
            Parameter(torch.Tensor(in_feats, n_hidden))
        )
        self.gnn_convs.append(
            GraphConv(
                in_feats=in_feats,
                out_feats=n_hidden,
                bias=False,
                activation=nn.RReLU(),
                weight=False,
            )
        )
        for _ in range(num_layers - 1):
            self.recurrent_layers.append(
                MatGRUCell(in_feats=n_hidden, out_feats=n_hidden)
            )
            self.gcn_weights_list.append(
                Parameter(torch.Tensor(n_hidden, n_hidden))
            )
            self.gnn_convs.append(
                GraphConv(
                    in_feats=n_hidden,
                    out_feats=n_hidden,
                    bias=False,
                    activation=nn.RReLU(),
                    weight=False,
                )
            )

        self.mlp = nn.Sequential(
            nn.Linear(n_hidden, classifier_hidden),
            nn.ReLU(),
            nn.Linear(classifier_hidden, n_classes),
        )
        self.reset_parameters()

    def reset_parameters(self):
@@ -176,7 +239,7 @@ class EvolveGCNO(nn.Module):
    def forward(self, g_list):
        feature_list = []
        for g in g_list:
            feature_list.append(g.ndata["feat"])
        for i in range(self.num_layers):
            W = self.gcn_weights_list[i]
            for j, g in enumerate(g_list):
@@ -191,5 +254,7 @@ class EvolveGCNO(nn.Module):
                # Remove the following line of code, it will become `GCN`.
                W = self.recurrent_layers[i](W)
                feature_list[j] = self.gnn_convs[i](
                    g, feature_list[j], weight=W
                )
        return self.mlp(feature_list[-1])
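
A hypothetical smoke test (not part of the commit) for the reformatted models: both consume a list of DGLGraphs, one per time step, and emit logits for the nodes of the last snapshot. The self loop keeps GraphConv's default normalization happy about zero-in-degree nodes.

# --- hypothetical smoke test, not from the diff ---
import torch
import dgl

g = dgl.add_self_loop(dgl.rand_graph(30, 100))
g.ndata["feat"] = torch.randn(30, 166)  # in_feats defaults to 166
model = EvolveGCNO()  # n_hidden=256, num_layers=2, n_classes=2
logits = model([g, g, g])  # a time window of three snapshots
print(logits.shape)  # torch.Size([30, 2])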
import argparse
import time

import torch
import torch.nn.functional as F
from dataset import EllipticDataset
from model import EvolveGCNH, EvolveGCNO
from utils import Measure

import dgl


def train(args, device):
    elliptic_dataset = EllipticDataset(
        raw_dir=args.raw_dir,
        processed_dir=args.processed_dir,
        self_loop=True,
        reverse_edge=True,
    )
    g, node_mask_by_time = elliptic_dataset.process()
    num_classes = elliptic_dataset.num_classes
@@ -24,18 +27,21 @@ def train(args, device):
        # we add self loop edges when we construct the full graph, not here
        node_subgraph = dgl.node_subgraph(graph=g, nodes=node_mask_by_time[i])
        cached_subgraph.append(node_subgraph.to(device))
        valid_node_mask = node_subgraph.ndata["label"] >= 0
        cached_labeled_node_mask.append(valid_node_mask)

    if args.model == "EvolveGCN-O":
        model = EvolveGCNO(
            in_feats=int(g.ndata["feat"].shape[1]),
            n_hidden=args.n_hidden,
            num_layers=args.n_layers,
        )
    elif args.model == "EvolveGCN-H":
        model = EvolveGCNH(
            in_feats=int(g.ndata["feat"].shape[1]), num_layers=args.n_layers
        )
    else:
        raise NotImplementedError("Unsupported model {}".format(args.model))
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -45,23 +51,35 @@ def train(args, device):
    valid_max_index = 35
    test_max_index = 48
    time_window_size = args.n_hist_steps
    loss_class_weight = [float(w) for w in args.loss_class_weight.split(",")]
    loss_class_weight = torch.Tensor(loss_class_weight).to(device)

    train_measure = Measure(
        num_classes=num_classes, target_class=args.eval_class_id
    )
    valid_measure = Measure(
        num_classes=num_classes, target_class=args.eval_class_id
    )
    test_measure = Measure(
        num_classes=num_classes, target_class=args.eval_class_id
    )

    test_res_f1 = 0
    for epoch in range(args.num_epochs):
        model.train()
        for i in range(time_window_size, train_max_index + 1):
            g_list = cached_subgraph[i - time_window_size : i + 1]
            predictions = model(g_list)
            # keep only predictions that have labels
            predictions = predictions[cached_labeled_node_mask[i]]
            labels = (
                cached_subgraph[i]
                .ndata["label"][cached_labeled_node_mask[i]]
                .long()
            )
            loss = F.cross_entropy(
                predictions, labels, weight=loss_class_weight
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
@@ -74,17 +92,24 @@ def train(args, device):
        # reset measures for the next epoch
        train_measure.reset_info()
        print(
            "Train Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
            )
        )

        # eval
        model.eval()
        for i in range(train_max_index + 1, valid_max_index + 1):
            g_list = cached_subgraph[i - time_window_size : i + 1]
            predictions = model(g_list)
            # keep only node predictions that have labels
            predictions = predictions[cached_labeled_node_mask[i]]
            labels = (
                cached_subgraph[i]
                .ndata["label"][cached_labeled_node_mask[i]]
                .long()
            )

            valid_measure.append_measures(predictions, labels)
@@ -94,30 +119,54 @@ def train(args, device):
        # reset measures for the next epoch
        valid_measure.reset_info()

        print(
            "Eval Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
            )
        )

        # early stop
        if epoch - valid_measure.target_best_f1_epoch >= args.patience:
            print(
                "Best eval Epoch {}, Cur Epoch {}".format(
                    valid_measure.target_best_f1_epoch, epoch
                )
            )
            break

        # if the current valid f1 score is the best, run the test
        if epoch == valid_measure.target_best_f1_epoch:
            print(
                "###################Epoch {} Test###################".format(
                    epoch
                )
            )
            for i in range(valid_max_index + 1, test_max_index + 1):
                g_list = cached_subgraph[i - time_window_size : i + 1]
                predictions = model(g_list)
                # keep only predictions that have labels
                predictions = predictions[cached_labeled_node_mask[i]]
                labels = (
                    cached_subgraph[i]
                    .ndata["label"][cached_labeled_node_mask[i]]
                    .long()
                )

                test_measure.append_measures(predictions, labels)

            # we report each subgraph's measure at test time to match fig 4 in the EvolveGCN paper.
            (
                cl_precisions,
                cl_recalls,
                cl_f1s,
            ) = test_measure.get_each_timestamp_measure()
            for index, (sub_p, sub_r, sub_f1) in enumerate(
                zip(cl_precisions, cl_recalls, cl_f1s)
            ):
                print(
                    " Test | Time {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                        valid_max_index + index + 2, sub_p, sub_r, sub_f1
                    )
                )

            # get each epoch's measure during test.
            cl_precision, cl_recall, cl_f1 = test_measure.get_total_measure()
@@ -127,51 +176,87 @@ def train(args, device):
            test_res_f1 = cl_f1

            print(
                " Test | Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                    epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
                )
            )

    print(
        "Best test f1 is {}, in Epoch {}".format(
            test_measure.target_best_f1, test_measure.target_best_f1_epoch
        )
    )
    if test_measure.target_best_f1_epoch != valid_measure.target_best_f1_epoch:
        print(
            "The Epoch get best Valid measure not get the best Test measure, "
            "please checkout the test result in Epoch {}, which f1 is {}".format(
                valid_measure.target_best_f1_epoch, test_res_f1
            )
        )


if __name__ == "__main__":
    argparser = argparse.ArgumentParser("EvolveGCN")
    argparser.add_argument(
        "--model",
        type=str,
        default="EvolveGCN-O",
        help="We can choose EvolveGCN-O or EvolveGCN-H,"
        "but the EvolveGCN-H performance on Elliptic dataset is not good.",
    )
    argparser.add_argument(
        "--raw-dir",
        type=str,
        default="/home/Elliptic/elliptic_bitcoin_dataset/",
        help="Dir after unzip downloaded dataset, which contains 3 csv files.",
    )
    argparser.add_argument(
        "--processed-dir",
        type=str,
        default="/home/Elliptic/processed/",
        help="Dir to store processed raw data.",
    )
    argparser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU device ID. Use -1 for CPU training.",
    )
    argparser.add_argument("--num-epochs", type=int, default=1000)
    argparser.add_argument("--n-hidden", type=int, default=256)
    argparser.add_argument("--n-layers", type=int, default=2)
    argparser.add_argument(
        "--n-hist-steps",
        type=int,
        default=5,
        help="If it is set to 5, it means in the first batch,"
        "we use historical data of 0-4 to predict the data of time 5.",
    )
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument(
        "--loss-class-weight",
        type=str,
        default="0.35,0.65",
        help="Weight for loss function. Follow the official code,"
        "we need to change it to 0.25, 0.75 when use EvolveGCN-H",
    )
    argparser.add_argument(
        "--eval-class-id",
        type=int,
        default=1,
        help="Class type to eval. On Elliptic, type 1(illicit) is the main interest.",
    )
    argparser.add_argument(
        "--patience", type=int, default=100, help="Patience for early stopping."
    )

    args = argparser.parse_args()

    if args.gpu >= 0:
        device = torch.device("cuda:%d" % args.gpu)
    else:
        device = torch.device("cpu")

    start_time = time.perf_counter()
    train(args, device)
......
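
For reference, a typical invocation of this script (assuming it is saved as train.py; the paths are just the argparse defaults above):

python train.py --model EvolveGCN-O --raw-dir /home/Elliptic/elliptic_bitcoin_dataset/ --processed-dir /home/Elliptic/processed/ --gpu 0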
@@ -31,18 +31,24 @@ class Measure(object):
    def reset_info(self):
        """
        reset info after each epoch.
        """
        self.true_positives = {
            cur_class: [] for cur_class in range(self.num_classes)
        }
        self.false_positives = {
            cur_class: [] for cur_class in range(self.num_classes)
        }
        self.false_negatives = {
            cur_class: [] for cur_class in range(self.num_classes)
        }

    def append_measures(self, predictions, labels):
        predicted_classes = predictions.argmax(dim=1)
        for cl in range(self.num_classes):
            cl_indices = labels == cl
            pos = predicted_classes == cl
            hits = predicted_classes[cl_indices] == labels[cl_indices]

            tp = hits.sum()
            fn = hits.size(0) - tp
......
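
For reference (not in the commit), the counters above feed the usual definitions: precision = tp / (tp + fp), recall = tp / (tp + fn), and f1 is their harmonic mean. A toy check:

# --- reference arithmetic, not from the diff ---
tp, fp, fn = 8, 2, 4
precision = tp / (tp + fp)  # 0.8
recall = tp / (tp + fn)  # ~0.667
f1 = 2 * precision * recall / (precision + recall)  # ~0.727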
import os

import numpy as np
import scipy.io as sio
import torch as th

import dgl
from dgl.data import DGLBuiltinDataset
from dgl.data.utils import _get_dgl_url, load_graphs, save_graphs


class GASDataset(DGLBuiltinDataset):
    file_urls = {"pol": "dataset/GASPOL.zip", "gos": "dataset/GASGOS.zip"}

    def __init__(
        self, name, raw_dir=None, random_seed=717, train_size=0.7, val_size=0.1
    ):
        assert name in ["gos", "pol"], "Only supports 'gos' or 'pol'."
        self.seed = random_seed
        self.train_size = train_size
        self.val_size = val_size
        url = _get_dgl_url(self.file_urls[name])
        super(GASDataset, self).__init__(name=name, url=url, raw_dir=raw_dir)

    def process(self):
        """process raw data to graph, labels and masks"""
        data = sio.loadmat(
            os.path.join(self.raw_path, f"{self.name}_retweet_graph.mat")
        )

        adj = data["graph"].tocoo()
        num_edges = len(adj.row)
        row, col = adj.row[: int(num_edges / 2)], adj.col[: int(num_edges / 2)]

        graph = dgl.graph(
            (np.concatenate((row, col)), np.concatenate((col, row)))
        )
        news_labels = data["label"].squeeze()
        num_news = len(news_labels)

        node_feature = np.load(
            os.path.join(self.raw_path, f"{self.name}_node_feature.npy")
        )
        edge_feature = np.load(
            os.path.join(self.raw_path, f"{self.name}_edge_feature.npy")
        )[: int(num_edges / 2)]

        graph.ndata["feat"] = th.tensor(node_feature)
        graph.edata["feat"] = th.tensor(np.tile(edge_feature, (2, 1)))
        pos_news = news_labels.nonzero()[0]

        edge_labels = th.zeros(num_edges)
        edge_labels[graph.in_edges(pos_news, form="eid")] = 1
        edge_labels[graph.out_edges(pos_news, form="eid")] = 1
        graph.edata["label"] = edge_labels

        ntypes = th.ones(graph.num_nodes(), dtype=int)
        etypes = th.ones(graph.num_edges(), dtype=int)
        ntypes[graph.nodes() < num_news] = 0
        etypes[: int(num_edges / 2)] = 0

        graph.ndata["_TYPE"] = ntypes
        graph.edata["_TYPE"] = etypes

        hg = dgl.to_heterogeneous(graph, ["v", "u"], ["forward", "backward"])
        self._random_split(hg, self.seed, self.train_size, self.val_size)

        self.graph = hg

    def save(self):
        """save the graph list and the labels"""
        graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
        save_graphs(str(graph_path), self.graph)

    def has_cache(self):
        """check whether there are processed data in `self.save_path`"""
        graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
        return os.path.exists(graph_path)

    def load(self):
        """load processed data from directory `self.save_path`"""
        graph_path = os.path.join(self.save_path, self.name + "_dgl_graph.bin")
        graph, _ = load_graphs(str(graph_path))
        self.graph = graph[0]
@@ -84,7 +91,7 @@ class GASDataset(DGLBuiltinDataset):
        return 2

    def __getitem__(self, idx):
        r"""Get graph object

        Parameters
        ----------
        idx : int
@@ -107,27 +114,29 @@ class GASDataset(DGLBuiltinDataset):
    def _random_split(self, graph, seed=717, train_size=0.7, val_size=0.1):
        """split the dataset into training set, validation set and testing set"""
        assert 0 <= train_size + val_size <= 1, (
            "The sum of valid training set size and validation set size "
            "must be between 0 and 1 (inclusive)."
        )

        num_edges = graph.num_edges(etype="forward")
        index = np.arange(num_edges)

        index = np.random.RandomState(seed).permutation(index)
        train_idx = index[: int(train_size * num_edges)]
        val_idx = index[num_edges - int(val_size * num_edges) :]
        test_idx = index[
            int(train_size * num_edges) : num_edges - int(val_size * num_edges)
        ]
        train_mask = np.zeros(num_edges, dtype=np.bool)
        val_mask = np.zeros(num_edges, dtype=np.bool)
        test_mask = np.zeros(num_edges, dtype=np.bool)
        train_mask[train_idx] = True
        val_mask[val_idx] = True
        test_mask[test_idx] = True
        graph.edges["forward"].data["train_mask"] = th.tensor(train_mask)
        graph.edges["forward"].data["val_mask"] = th.tensor(val_mask)
        graph.edges["forward"].data["test_mask"] = th.tensor(test_mask)
        graph.edges["backward"].data["train_mask"] = th.tensor(train_mask)
        graph.edges["backward"].data["val_mask"] = th.tensor(val_mask)
        graph.edges["backward"].data["test_mask"] = th.tensor(test_mask)
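
A hypothetical usage sketch (not part of the commit) showing how the split masks written by _random_split come back out; GASDataset("pol") will download and process the data on first use:

# --- hypothetical usage, not from the diff ---
dataset = GASDataset("pol")
hg = dataset[0]
train_mask = hg.edges["forward"].data["train_mask"]
print(hg.num_edges(etype="forward"), int(train_mask.sum()))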
import argparse

import torch as th
import torch.nn.functional as F
import torch.optim as optim
from dataloader import GASDataset
from model import GAS
from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score


def main(args):
@@ -15,25 +16,25 @@ def main(args):
    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    # binary classification
    num_classes = dataset.num_classes

    # retrieve ground-truth labels
    labels = graph.edges["forward"].data["label"].to(device).long()

    # extract node and edge features
    e_feat = graph.edges["forward"].data["feat"].to(device)
    u_feat = graph.nodes["u"].data["feat"].to(device)
    v_feat = graph.nodes["v"].data["feat"].to(device)

    # retrieve masks for train/validation/test
    train_mask = graph.edges["forward"].data["train_mask"]
    val_mask = graph.edges["forward"].data["val_mask"]
    test_mask = graph.edges["forward"].data["test_mask"]

    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
@@ -42,22 +43,26 @@ def main(args):
    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = GAS(
        e_in_dim=e_feat.shape[-1],
        u_in_dim=u_feat.shape[-1],
        v_in_dim=v_feat.shape[-1],
        e_hid_dim=args.e_hid_dim,
        u_hid_dim=args.u_hid_dim,
        v_hid_dim=args.v_hid_dim,
        out_dim=num_classes,
        num_layers=args.num_layers,
        dropout=args.dropout,
        activation=F.relu,
    )
    model = model.to(device)

    # Step 3: Create training components ===================================================== #
    loss_fn = th.nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Step 4: training epochs =============================================================== #
    for epoch in range(args.max_epoch):
@@ -67,16 +72,28 @@ def main(args):
        # compute loss
        tr_loss = loss_fn(logits[train_idx], labels[train_idx])
        tr_f1 = f1_score(
            labels[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu()
        )
        tr_auc = roc_auc_score(
            labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu()
        )
        tr_pre, tr_re, _ = precision_recall_curve(
            labels[train_idx].cpu(), logits[train_idx][:, 1].detach().cpu()
        )
        tr_rap = tr_re[tr_pre > args.precision].max()

        # validation
        valid_loss = loss_fn(logits[val_idx], labels[val_idx])
        valid_f1 = f1_score(
            labels[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu()
        )
        valid_auc = roc_auc_score(
            labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu()
        )
        valid_pre, valid_re, _ = precision_recall_curve(
            labels[val_idx].cpu(), logits[val_idx][:, 1].detach().cpu()
        )
        valid_rap = valid_re[valid_pre > args.precision].max()

        # backward
...@@ -85,9 +102,20 @@ def main(args): ...@@ -85,9 +102,20 @@ def main(args):
optimizer.step() optimizer.step()
# Print out performance # Print out performance
print("In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; " print(
"Valid R@P: {:.4f} | Valid F1: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}". "In epoch {}, Train R@P: {:.4f} | Train F1: {:.4f} | Train AUC: {:.4f} | Train Loss: {:.4f}; "
format(epoch, tr_rap, tr_f1, tr_auc, tr_loss.item(), valid_rap, valid_f1, valid_auc, valid_loss.item())) "Valid R@P: {:.4f} | Valid F1: {:.4f} | Valid AUC: {:.4f} | Valid loss: {:.4f}".format(
epoch,
tr_rap,
tr_f1,
tr_auc,
tr_loss.item(),
valid_rap,
valid_f1,
valid_auc,
valid_loss.item(),
)
)
# Test after all epoch # Test after all epoch
model.eval() model.eval()
...@@ -97,28 +125,77 @@ def main(args): ...@@ -97,28 +125,77 @@ def main(args):
# compute loss # compute loss
test_loss = loss_fn(logits[test_idx], labels[test_idx]) test_loss = loss_fn(logits[test_idx], labels[test_idx])
test_f1 = f1_score(labels[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu()) test_f1 = f1_score(
test_auc = roc_auc_score(labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu()) labels[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu()
test_pre, test_re, _ = precision_recall_curve(labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu()) )
test_auc = roc_auc_score(
labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu()
)
test_pre, test_re, _ = precision_recall_curve(
labels[test_idx].cpu(), logits[test_idx][:, 1].detach().cpu()
)
test_rap = test_re[test_pre > args.precision].max() test_rap = test_re[test_pre > args.precision].max()
print("Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}". print(
format(test_rap, test_f1, test_auc, test_loss.item())) "Test R@P: {:.4f} | Test F1: {:.4f} | Test AUC: {:.4f} | Test loss: {:.4f}".format(
test_rap, test_f1, test_auc, test_loss.item()
)
if __name__ == '__main__': )
parser = argparse.ArgumentParser(description='GCN-based Anti-Spam Model')
parser.add_argument("--dataset", type=str, default="pol", help="'pol', or 'gos'")
parser.add_argument("--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU.") if __name__ == "__main__":
parser.add_argument("--e_hid_dim", type=int, default=128, help="Hidden layer dimension for edges") parser = argparse.ArgumentParser(description="GCN-based Anti-Spam Model")
parser.add_argument("--u_hid_dim", type=int, default=128, help="Hidden layer dimension for source nodes") parser.add_argument(
parser.add_argument("--v_hid_dim", type=int, default=128, help="Hidden layer dimension for destination nodes") "--dataset", type=str, default="pol", help="'pol', or 'gos'"
parser.add_argument("--num_layers", type=int, default=2, help="Number of GCN layers") )
parser.add_argument("--max_epoch", type=int, default=100, help="The max number of epochs. Default: 100") parser.add_argument(
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate. Default: 1e-3") "--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU."
parser.add_argument("--dropout", type=float, default=0.0, help="Dropout rate. Default: 0.0") )
parser.add_argument("--weight_decay", type=float, default=5e-4, help="Weight Decay. Default: 0.0005") parser.add_argument(
parser.add_argument("--precision", type=float, default=0.9, help="The value p in recall@p precision. Default: 0.9") "--e_hid_dim",
type=int,
default=128,
help="Hidden layer dimension for edges",
)
parser.add_argument(
"--u_hid_dim",
type=int,
default=128,
help="Hidden layer dimension for source nodes",
)
parser.add_argument(
"--v_hid_dim",
type=int,
default=128,
help="Hidden layer dimension for destination nodes",
)
parser.add_argument(
"--num_layers", type=int, default=2, help="Number of GCN layers"
)
parser.add_argument(
"--max_epoch",
type=int,
default=100,
help="The max number of epochs. Default: 100",
)
parser.add_argument(
"--lr", type=float, default=0.001, help="Learning rate. Default: 1e-3"
)
parser.add_argument(
"--dropout", type=float, default=0.0, help="Dropout rate. Default: 0.0"
)
parser.add_argument(
"--weight_decay",
type=float,
default=5e-4,
help="Weight Decay. Default: 0.0005",
)
parser.add_argument(
"--precision",
type=float,
default=0.9,
help="The value p in recall@p precision. Default: 0.9",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
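    # A minimal sketch of how this script could be launched (the file name
    # `main.py` is an assumption; all flags come from the parser above):
    #
    #   python main.py --dataset pol --gpu 0 --num_layers 2 --lr 1e-3 --max_epoch 100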
...
import torch as th
import torch.nn as nn import torch.nn as nn
import dgl.function as fn import dgl.function as fn
import torch as th
from dgl.nn.functional import edge_softmax from dgl.nn.functional import edge_softmax
...@@ -10,33 +11,35 @@ class MLP(nn.Module): ...@@ -10,33 +11,35 @@ class MLP(nn.Module):
self.W = nn.Linear(in_dim, out_dim) self.W = nn.Linear(in_dim, out_dim)
def apply_edges(self, edges): def apply_edges(self, edges):
h_e = edges.data['h'] h_e = edges.data["h"]
h_u = edges.src['h'] h_u = edges.src["h"]
h_v = edges.dst['h'] h_v = edges.dst["h"]
score = self.W(th.cat([h_e, h_u, h_v], -1)) score = self.W(th.cat([h_e, h_u, h_v], -1))
return {'score': score} return {"score": score}
def forward(self, g, e_feat, u_feat, v_feat): def forward(self, g, e_feat, u_feat, v_feat):
with g.local_scope(): with g.local_scope():
g.edges['forward'].data['h'] = e_feat g.edges["forward"].data["h"] = e_feat
g.nodes['u'].data['h'] = u_feat g.nodes["u"].data["h"] = u_feat
g.nodes['v'].data['h'] = v_feat g.nodes["v"].data["h"] = v_feat
g.apply_edges(self.apply_edges, etype="forward") g.apply_edges(self.apply_edges, etype="forward")
return g.edges['forward'].data['score'] return g.edges["forward"].data["score"]
class GASConv(nn.Module): class GASConv(nn.Module):
"""One layer of GAS.""" """One layer of GAS."""
def __init__(self, def __init__(
e_in_dim, self,
u_in_dim, e_in_dim,
v_in_dim, u_in_dim,
e_out_dim, v_in_dim,
u_out_dim, e_out_dim,
v_out_dim, u_out_dim,
activation=None, v_out_dim,
dropout=0): activation=None,
dropout=0,
):
super(GASConv, self).__init__() super(GASConv, self).__init__()
self.activation = activation self.activation = activation
...@@ -61,47 +64,82 @@ class GASConv(nn.Module): ...@@ -61,47 +64,82 @@ class GASConv(nn.Module):
def forward(self, g, e_feat, u_feat, v_feat): def forward(self, g, e_feat, u_feat, v_feat):
with g.local_scope(): with g.local_scope():
g.nodes['u'].data['h'] = u_feat g.nodes["u"].data["h"] = u_feat
g.nodes['v'].data['h'] = v_feat g.nodes["v"].data["h"] = v_feat
g.edges['forward'].data['h'] = e_feat g.edges["forward"].data["h"] = e_feat
g.edges['backward'].data['h'] = e_feat g.edges["backward"].data["h"] = e_feat
# formula 3 and 4 (optimized implementation to save memory) # formula 3 and 4 (optimized implementation to save memory)
g.nodes["u"].data.update({'he_u': self.u_linear(u_feat)}) g.nodes["u"].data.update({"he_u": self.u_linear(u_feat)})
g.nodes["v"].data.update({'he_v': self.v_linear(v_feat)}) g.nodes["v"].data.update({"he_v": self.v_linear(v_feat)})
g.edges["forward"].data.update({'he_e': self.e_linear(e_feat)}) g.edges["forward"].data.update({"he_e": self.e_linear(e_feat)})
g.apply_edges(lambda edges: {'he': edges.data['he_e'] + edges.src['he_u'] + edges.dst['he_v']}, etype='forward') g.apply_edges(
he = g.edges["forward"].data['he'] lambda edges: {
"he": edges.data["he_e"]
+ edges.src["he_u"]
+ edges.dst["he_v"]
},
etype="forward",
)
he = g.edges["forward"].data["he"]
if self.activation is not None: if self.activation is not None:
he = self.activation(he) he = self.activation(he)
# formula 6 # formula 6
g.apply_edges(lambda edges: {'h_ve': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='backward') g.apply_edges(
g.apply_edges(lambda edges: {'h_ue': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='forward') lambda edges: {
"h_ve": th.cat([edges.src["h"], edges.data["h"]], -1)
},
etype="backward",
)
g.apply_edges(
lambda edges: {
"h_ue": th.cat([edges.src["h"], edges.data["h"]], -1)
},
etype="forward",
)
# formula 7, self-attention # formula 7, self-attention
g.nodes['u'].data['h_att_u'] = self.W_ATTN_u(u_feat) g.nodes["u"].data["h_att_u"] = self.W_ATTN_u(u_feat)
g.nodes['v'].data['h_att_v'] = self.W_ATTN_v(v_feat) g.nodes["v"].data["h_att_v"] = self.W_ATTN_v(v_feat)
# Step 1: dot product # Step 1: dot product
g.apply_edges(fn.e_dot_v('h_ve', 'h_att_u', 'edotv'), etype='backward') g.apply_edges(
g.apply_edges(fn.e_dot_v('h_ue', 'h_att_v', 'edotv'), etype='forward') fn.e_dot_v("h_ve", "h_att_u", "edotv"), etype="backward"
)
g.apply_edges(
fn.e_dot_v("h_ue", "h_att_v", "edotv"), etype="forward"
)
# Step 2. softmax # Step 2. softmax
g.edges['backward'].data['sfm'] = edge_softmax(g['backward'], g.edges['backward'].data['edotv']) g.edges["backward"].data["sfm"] = edge_softmax(
g.edges['forward'].data['sfm'] = edge_softmax(g['forward'], g.edges['forward'].data['edotv']) g["backward"], g.edges["backward"].data["edotv"]
)
g.edges["forward"].data["sfm"] = edge_softmax(
g["forward"], g.edges["forward"].data["edotv"]
)
# Step 3. Broadcast softmax value to each edge, and then attention is done # Step 3. Broadcast softmax value to each edge, and then attention is done
g.apply_edges(lambda edges: {'attn': edges.data['h_ve'] * edges.data['sfm']}, etype='backward') g.apply_edges(
g.apply_edges(lambda edges: {'attn': edges.data['h_ue'] * edges.data['sfm']}, etype='forward') lambda edges: {"attn": edges.data["h_ve"] * edges.data["sfm"]},
etype="backward",
)
g.apply_edges(
lambda edges: {"attn": edges.data["h_ue"] * edges.data["sfm"]},
etype="forward",
)
# Step 4. Aggregate attention to dst,user nodes, so formula 7 is done # Step 4. Aggregate attention to dst,user nodes, so formula 7 is done
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'), etype='backward') g.update_all(
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_v'), etype='forward') fn.copy_e("attn", "m"), fn.sum("m", "agg_u"), etype="backward"
)
g.update_all(
fn.copy_e("attn", "m"), fn.sum("m", "agg_v"), etype="forward"
)
# formula 5 # formula 5
h_nu = self.W_u(g.nodes['u'].data['agg_u']) h_nu = self.W_u(g.nodes["u"].data["agg_u"])
h_nv = self.W_v(g.nodes['v'].data['agg_v']) h_nv = self.W_v(g.nodes["v"].data["agg_v"])
if self.activation is not None: if self.activation is not None:
h_nu = self.activation(h_nu) h_nu = self.activation(h_nu)
h_nv = self.activation(h_nv) h_nv = self.activation(h_nv)
...@@ -119,17 +157,19 @@ class GASConv(nn.Module): ...@@ -119,17 +157,19 @@ class GASConv(nn.Module):
class GAS(nn.Module): class GAS(nn.Module):
def __init__(self, def __init__(
e_in_dim, self,
u_in_dim, e_in_dim,
v_in_dim, u_in_dim,
e_hid_dim, v_in_dim,
u_hid_dim, e_hid_dim,
v_hid_dim, u_hid_dim,
out_dim, v_hid_dim,
num_layers=2, out_dim,
dropout=0.0, num_layers=2,
activation=None): dropout=0.0,
activation=None,
):
super(GAS, self).__init__() super(GAS, self).__init__()
self.e_in_dim = e_in_dim self.e_in_dim = e_in_dim
self.u_in_dim = u_in_dim self.u_in_dim = u_in_dim
...@@ -145,25 +185,33 @@ class GAS(nn.Module): ...@@ -145,25 +185,33 @@ class GAS(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# Input layer # Input layer
self.layers.append(GASConv(self.e_in_dim, self.layers.append(
self.u_in_dim, GASConv(
self.v_in_dim, self.e_in_dim,
self.e_hid_dim, self.u_in_dim,
self.u_hid_dim, self.v_in_dim,
self.v_hid_dim, self.e_hid_dim,
activation=self.activation, self.u_hid_dim,
dropout=self.dropout)) self.v_hid_dim,
activation=self.activation,
dropout=self.dropout,
)
)
# Hidden layers with n - 1 CompGraphConv layers # Hidden layers with n - 1 CompGraphConv layers
for i in range(self.num_layer - 1): for i in range(self.num_layer - 1):
self.layers.append(GASConv(self.e_hid_dim, self.layers.append(
self.u_hid_dim, GASConv(
self.v_hid_dim, self.e_hid_dim,
self.e_hid_dim, self.u_hid_dim,
self.u_hid_dim, self.v_hid_dim,
self.v_hid_dim, self.e_hid_dim,
activation=self.activation, self.u_hid_dim,
dropout=self.dropout)) self.v_hid_dim,
activation=self.activation,
dropout=self.dropout,
)
)
def forward(self, graph, e_feat, u_feat, v_feat): def forward(self, graph, e_feat, u_feat, v_feat):
# For full graph training, directly use the graph # For full graph training, directly use the graph
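            # Aside on Step 2 above (an explanatory sketch, not part of the
            # original commit): dgl.nn.functional.edge_softmax normalizes the
            # raw per-edge scores ("edotv") over the incoming edges of each
            # destination node, so the resulting "sfm" weights on every
            # node's in-edges sum to 1 before the weighted aggregation in
            # Step 4. A tiny illustration with made-up values:
            #
            #   import dgl
            #   import torch as th
            #   from dgl.nn.functional import edge_softmax
            #
            #   g = dgl.graph(([0, 0, 1], [1, 2, 2]))   # edges 0->1, 0->2, 1->2
            #   logits = th.tensor([[1.0], [2.0], [3.0]])
            #   w = edge_softmax(g, logits)              # softmax per destination node
            #   assert th.isclose(w[1] + w[2], th.tensor([1.0]))  # both enter node 2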
...
import torch as th
import torch.nn as nn import torch.nn as nn
import dgl.function as fn import dgl.function as fn
import torch as th
from dgl.nn.functional import edge_softmax from dgl.nn.functional import edge_softmax
...@@ -10,33 +11,35 @@ class MLP(nn.Module): ...@@ -10,33 +11,35 @@ class MLP(nn.Module):
self.W = nn.Linear(in_dim, out_dim) self.W = nn.Linear(in_dim, out_dim)
def apply_edges(self, edges): def apply_edges(self, edges):
h_e = edges.data['h'] h_e = edges.data["h"]
h_u = edges.src['h'] h_u = edges.src["h"]
h_v = edges.dst['h'] h_v = edges.dst["h"]
score = self.W(th.cat([h_e, h_u, h_v], -1)) score = self.W(th.cat([h_e, h_u, h_v], -1))
return {'score': score} return {"score": score}
def forward(self, g, e_feat, u_feat, v_feat): def forward(self, g, e_feat, u_feat, v_feat):
with g.local_scope(): with g.local_scope():
g.edges['forward'].data['h'] = e_feat g.edges["forward"].data["h"] = e_feat
g.nodes['u'].data['h'] = u_feat g.nodes["u"].data["h"] = u_feat
g.nodes['v'].data['h'] = v_feat g.nodes["v"].data["h"] = v_feat
g.apply_edges(self.apply_edges, etype="forward") g.apply_edges(self.apply_edges, etype="forward")
return g.edges['forward'].data['score'] return g.edges["forward"].data["score"]
class GASConv(nn.Module): class GASConv(nn.Module):
"""One layer of GAS.""" """One layer of GAS."""
def __init__(self, def __init__(
e_in_dim, self,
u_in_dim, e_in_dim,
v_in_dim, u_in_dim,
e_out_dim, v_in_dim,
u_out_dim, e_out_dim,
v_out_dim, u_out_dim,
activation=None, v_out_dim,
dropout=0): activation=None,
dropout=0,
):
super(GASConv, self).__init__() super(GASConv, self).__init__()
self.activation = activation self.activation = activation
...@@ -60,57 +63,107 @@ class GASConv(nn.Module): ...@@ -60,57 +63,107 @@ class GASConv(nn.Module):
self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim) self.Vv = nn.Linear(v_in_dim, v_out_dim - nv_dim)
def forward(self, g, f_feat, b_feat, u_feat, v_feat): def forward(self, g, f_feat, b_feat, u_feat, v_feat):
g.srcnodes['u'].data['h'] = u_feat g.srcnodes["u"].data["h"] = u_feat
g.srcnodes['v'].data['h'] = v_feat g.srcnodes["v"].data["h"] = v_feat
g.dstnodes['u'].data['h'] = u_feat[:g.number_of_dst_nodes(ntype='u')] g.dstnodes["u"].data["h"] = u_feat[: g.number_of_dst_nodes(ntype="u")]
g.dstnodes['v'].data['h'] = v_feat[:g.number_of_dst_nodes(ntype='v')] g.dstnodes["v"].data["h"] = v_feat[: g.number_of_dst_nodes(ntype="v")]
g.edges['forward'].data['h'] = f_feat g.edges["forward"].data["h"] = f_feat
g.edges['backward'].data['h'] = b_feat g.edges["backward"].data["h"] = b_feat
# formula 3 and 4 (optimized implementation to save memory) # formula 3 and 4 (optimized implementation to save memory)
g.srcnodes["u"].data.update({'he_u': self.u_linear(g.srcnodes['u'].data['h'])}) g.srcnodes["u"].data.update(
g.srcnodes["v"].data.update({'he_v': self.v_linear(g.srcnodes['v'].data['h'])}) {"he_u": self.u_linear(g.srcnodes["u"].data["h"])}
g.dstnodes["u"].data.update({'he_u': self.u_linear(g.dstnodes['u'].data['h'])}) )
g.dstnodes["v"].data.update({'he_v': self.v_linear(g.dstnodes['v'].data['h'])}) g.srcnodes["v"].data.update(
g.edges["forward"].data.update({'he_e': self.e_linear(f_feat)}) {"he_v": self.v_linear(g.srcnodes["v"].data["h"])}
g.edges["backward"].data.update({'he_e': self.e_linear(b_feat)}) )
g.apply_edges(lambda edges: {'he': edges.data['he_e'] + edges.dst['he_u'] + edges.src['he_v']}, etype='backward') g.dstnodes["u"].data.update(
g.apply_edges(lambda edges: {'he': edges.data['he_e'] + edges.src['he_u'] + edges.dst['he_v']}, etype='forward') {"he_u": self.u_linear(g.dstnodes["u"].data["h"])}
hf = g.edges["forward"].data['he'] )
hb = g.edges["backward"].data['he'] g.dstnodes["v"].data.update(
{"he_v": self.v_linear(g.dstnodes["v"].data["h"])}
)
g.edges["forward"].data.update({"he_e": self.e_linear(f_feat)})
g.edges["backward"].data.update({"he_e": self.e_linear(b_feat)})
g.apply_edges(
lambda edges: {
"he": edges.data["he_e"] + edges.dst["he_u"] + edges.src["he_v"]
},
etype="backward",
)
g.apply_edges(
lambda edges: {
"he": edges.data["he_e"] + edges.src["he_u"] + edges.dst["he_v"]
},
etype="forward",
)
hf = g.edges["forward"].data["he"]
hb = g.edges["backward"].data["he"]
if self.activation is not None: if self.activation is not None:
hf = self.activation(hf) hf = self.activation(hf)
hb = self.activation(hb) hb = self.activation(hb)
# formula 6 # formula 6
g.apply_edges(lambda edges: {'h_ve': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='backward') g.apply_edges(
g.apply_edges(lambda edges: {'h_ue': th.cat([edges.src['h'], edges.data['h']], -1)}, etype='forward') lambda edges: {
"h_ve": th.cat([edges.src["h"], edges.data["h"]], -1)
},
etype="backward",
)
g.apply_edges(
lambda edges: {
"h_ue": th.cat([edges.src["h"], edges.data["h"]], -1)
},
etype="forward",
)
# formula 7, self-attention # formula 7, self-attention
g.srcnodes['u'].data['h_att_u'] = self.W_ATTN_u(g.srcnodes['u'].data['h']) g.srcnodes["u"].data["h_att_u"] = self.W_ATTN_u(
g.srcnodes['v'].data['h_att_v'] = self.W_ATTN_v(g.srcnodes['v'].data['h']) g.srcnodes["u"].data["h"]
g.dstnodes['u'].data['h_att_u'] = self.W_ATTN_u(g.dstnodes['u'].data['h']) )
g.dstnodes['v'].data['h_att_v'] = self.W_ATTN_v(g.dstnodes['v'].data['h']) g.srcnodes["v"].data["h_att_v"] = self.W_ATTN_v(
g.srcnodes["v"].data["h"]
)
g.dstnodes["u"].data["h_att_u"] = self.W_ATTN_u(
g.dstnodes["u"].data["h"]
)
g.dstnodes["v"].data["h_att_v"] = self.W_ATTN_v(
g.dstnodes["v"].data["h"]
)
# Step 1: dot product # Step 1: dot product
g.apply_edges(fn.e_dot_v('h_ve', 'h_att_u', 'edotv'), etype='backward') g.apply_edges(fn.e_dot_v("h_ve", "h_att_u", "edotv"), etype="backward")
g.apply_edges(fn.e_dot_v('h_ue', 'h_att_v', 'edotv'), etype='forward') g.apply_edges(fn.e_dot_v("h_ue", "h_att_v", "edotv"), etype="forward")
# Step 2. softmax # Step 2. softmax
g.edges['backward'].data['sfm'] = edge_softmax(g['backward'], g.edges['backward'].data['edotv']) g.edges["backward"].data["sfm"] = edge_softmax(
g.edges['forward'].data['sfm'] = edge_softmax(g['forward'], g.edges['forward'].data['edotv']) g["backward"], g.edges["backward"].data["edotv"]
)
g.edges["forward"].data["sfm"] = edge_softmax(
g["forward"], g.edges["forward"].data["edotv"]
)
# Step 3. Broadcast softmax value to each edge, and then attention is done # Step 3. Broadcast softmax value to each edge, and then attention is done
g.apply_edges(lambda edges: {'attn': edges.data['h_ve'] * edges.data['sfm']}, etype='backward') g.apply_edges(
g.apply_edges(lambda edges: {'attn': edges.data['h_ue'] * edges.data['sfm']}, etype='forward') lambda edges: {"attn": edges.data["h_ve"] * edges.data["sfm"]},
etype="backward",
)
g.apply_edges(
lambda edges: {"attn": edges.data["h_ue"] * edges.data["sfm"]},
etype="forward",
)
# Step 4. Aggregate attention to dst,user nodes, so formula 7 is done # Step 4. Aggregate attention to dst,user nodes, so formula 7 is done
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_u'), etype='backward') g.update_all(
g.update_all(fn.copy_e('attn', 'm'), fn.sum('m', 'agg_v'), etype='forward') fn.copy_e("attn", "m"), fn.sum("m", "agg_u"), etype="backward"
)
g.update_all(
fn.copy_e("attn", "m"), fn.sum("m", "agg_v"), etype="forward"
)
# formula 5 # formula 5
h_nu = self.W_u(g.dstnodes['u'].data['agg_u']) h_nu = self.W_u(g.dstnodes["u"].data["agg_u"])
h_nv = self.W_v(g.dstnodes['v'].data['agg_v']) h_nv = self.W_v(g.dstnodes["v"].data["agg_v"])
if self.activation is not None: if self.activation is not None:
h_nu = self.activation(h_nu) h_nu = self.activation(h_nu)
h_nv = self.activation(h_nv) h_nv = self.activation(h_nv)
...@@ -122,24 +175,26 @@ class GASConv(nn.Module): ...@@ -122,24 +175,26 @@ class GASConv(nn.Module):
h_nv = self.dropout(h_nv) h_nv = self.dropout(h_nv)
# formula 8 # formula 8
hu = th.cat([self.Vu(g.dstnodes['u'].data['h']), h_nu], -1) hu = th.cat([self.Vu(g.dstnodes["u"].data["h"]), h_nu], -1)
hv = th.cat([self.Vv(g.dstnodes['v'].data['h']), h_nv], -1) hv = th.cat([self.Vv(g.dstnodes["v"].data["h"]), h_nv], -1)
return hf, hb, hu, hv return hf, hb, hu, hv
class GAS(nn.Module): class GAS(nn.Module):
def __init__(self, def __init__(
e_in_dim, self,
u_in_dim, e_in_dim,
v_in_dim, u_in_dim,
e_hid_dim, v_in_dim,
u_hid_dim, e_hid_dim,
v_hid_dim, u_hid_dim,
out_dim, v_hid_dim,
num_layers=2, out_dim,
dropout=0.0, num_layers=2,
activation=None): dropout=0.0,
activation=None,
):
super(GAS, self).__init__() super(GAS, self).__init__()
self.e_in_dim = e_in_dim self.e_in_dim = e_in_dim
self.u_in_dim = u_in_dim self.u_in_dim = u_in_dim
...@@ -155,34 +210,49 @@ class GAS(nn.Module): ...@@ -155,34 +210,49 @@ class GAS(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
# Input layer # Input layer
self.layers.append(GASConv(self.e_in_dim, self.layers.append(
self.u_in_dim, GASConv(
self.v_in_dim, self.e_in_dim,
self.e_hid_dim, self.u_in_dim,
self.u_hid_dim, self.v_in_dim,
self.v_hid_dim, self.e_hid_dim,
activation=self.activation, self.u_hid_dim,
dropout=self.dropout)) self.v_hid_dim,
activation=self.activation,
dropout=self.dropout,
)
)
# Hidden layers with n - 1 CompGraphConv layers # Hidden layers with n - 1 CompGraphConv layers
for i in range(self.num_layer - 1): for i in range(self.num_layer - 1):
self.layers.append(GASConv(self.e_hid_dim, self.layers.append(
self.u_hid_dim, GASConv(
self.v_hid_dim, self.e_hid_dim,
self.e_hid_dim, self.u_hid_dim,
self.u_hid_dim, self.v_hid_dim,
self.v_hid_dim, self.e_hid_dim,
activation=self.activation, self.u_hid_dim,
dropout=self.dropout)) self.v_hid_dim,
activation=self.activation,
dropout=self.dropout,
)
)
def forward(self, subgraph, blocks, f_feat, b_feat, u_feat, v_feat): def forward(self, subgraph, blocks, f_feat, b_feat, u_feat, v_feat):
# Forward of n layers of GAS # Forward of n layers of GAS
for layer, block in zip(self.layers, blocks): for layer, block in zip(self.layers, blocks):
f_feat, b_feat, u_feat, v_feat = layer(block, f_feat, b_feat, u_feat, v_feat = layer(
f_feat[:block.num_edges(etype='forward')], block,
b_feat[:block.num_edges(etype='backward')], f_feat[: block.num_edges(etype="forward")],
u_feat, b_feat[: block.num_edges(etype="backward")],
v_feat) u_feat,
v_feat,
)
# return the result of final prediction layer # return the result of final prediction layer
return self.predictor(subgraph, f_feat[:subgraph.num_edges(etype='forward')], u_feat, v_feat) return self.predictor(
subgraph,
f_feat[: subgraph.num_edges(etype="forward")],
u_feat,
v_feat,
)
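        # Note (an explanatory aside, not part of the original commit): in
        # this sampling variant each message-flow-graph block lists its
        # destination nodes first among the source nodes, which is why the
        # dst features are recovered by slicing, e.g.
        # `u_feat[: g.number_of_dst_nodes(ntype="u")]`, and why the per-layer
        # edge features are cut down with
        # `f_feat[: block.num_edges(etype="forward")]` before each layer.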
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.nn as dglnn
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class GAT(nn.Module):
    def __init__(self, in_size, hid_size, out_size, heads):
        super().__init__()
        self.gat_layers = nn.ModuleList()
        # two-layer GAT
        self.gat_layers.append(
            dglnn.GATConv(
                in_size,
                hid_size,
                heads[0],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=F.elu,
            )
        )
        self.gat_layers.append(
            dglnn.GATConv(
                hid_size * heads[0],
                out_size,
                heads[1],
                feat_drop=0.6,
                attn_drop=0.6,
                activation=None,
            )
        )

    def forward(self, g, inputs):
        h = inputs
        for i, layer in enumerate(self.gat_layers):
            h = layer(g, h)
            if i == 1:  # last layer
                h = h.mean(1)
            else:  # other layer(s)
                h = h.flatten(1)
        return h
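        # Note (explanatory, not part of the original commit): the hidden GAT
        # layer concatenates its attention heads
        # (h.flatten(1) -> [N, heads * hid_size]), while the final layer
        # averages them (h.mean(1) -> [N, out_size]), following the original
        # GAT recipe.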
def evaluate(g, features, labels, mask, model):
    model.eval()
    with torch.no_grad():

@@ -33,7 +55,8 @@ def evaluate(g, features, labels, mask, model):
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def train(g, features, labels, masks, model):
    # define train/val samples, loss function and optimizer
    train_mask = masks[0]

@@ -41,7 +64,7 @@ def train(g, features, labels, masks, model):
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-3, weight_decay=5e-4)

    # training loop
    for epoch in range(200):
        model.train()
        logits = model(g, features)

@@ -50,43 +73,53 @@ def train(g, features, labels, masks, model):
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GATConv module.")

    # load and preprocess dataset
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

    # create GAT model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = GAT(in_size, 8, out_size, heads=[8, 1]).to(device)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))
@@ -8,39 +8,75 @@ Author's code: https://github.com/tech-srl/how_attentive_are_gats
import torch
import torch.nn as nn
from dgl.nn import GATv2Conv


class GATv2(nn.Module):
    def __init__(
        self,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
    ):
        super(GATv2, self).__init__()
        self.num_layers = num_layers
        self.gatv2_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gatv2_layers.append(
            GATv2Conv(
                in_dim,
                num_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                self.activation,
                bias=False,
                share_weights=True,
            )
        )
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gatv2_layers.append(
                GATv2Conv(
                    num_hidden * heads[l - 1],
                    num_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                    bias=False,
                    share_weights=True,
                )
            )
        # output projection
        self.gatv2_layers.append(
            GATv2Conv(
                num_hidden * heads[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                residual,
                None,
                bias=False,
                share_weights=True,
            )
        )

    def forward(self, g, inputs):
        h = inputs
...
@@ -4,16 +4,17 @@ Multiple heads are also batched together for faster training.
"""

import argparse
import time

import numpy as np
import torch
import torch.nn.functional as F
from gatv2 import GATv2

import dgl
from dgl.data import (CiteseerGraphDataset, CoraGraphDataset,
                      PubmedGraphDataset, register_data_args)


class EarlyStopping:
    def __init__(self, patience=10):

@@ -29,7 +30,9 @@ class EarlyStopping:
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:

@@ -39,8 +42,9 @@ class EarlyStopping:
        return self.early_stop

    def save_checkpoint(self, model):
        """Saves model when validation loss decreases."""
        torch.save(model.state_dict(), "es_checkpoint.pt")


def accuracy(logits, labels):
    _, indices = torch.max(logits, dim=1)

@@ -59,14 +63,14 @@ def evaluate(g, model, features, labels, mask):
def main(args):
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:

@@ -75,24 +79,29 @@ def main(args):
        cuda = True
        g = g.int().to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    num_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.int().sum().item(),
            val_mask.int().sum().item(),
            test_mask.int().sum().item(),
        )
    )

    # add self loop
    g = dgl.remove_self_loop(g)

@@ -100,16 +109,18 @@ def main(args):
    n_edges = g.number_of_edges()
    # create model
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = GATv2(
        args.num_layers,
        num_feats,
        args.num_hidden,
        n_classes,
        heads,
        F.elu,
        args.in_drop,
        args.attn_drop,
        args.negative_slope,
        args.residual,
    )
    print(model)
    if args.early_stop:
        stopper = EarlyStopping(patience=100)

@@ -119,7 +130,8 @@ def main(args):
    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # initialize graph
    dur = []

@@ -148,50 +160,90 @@ def main(args):
            if stopper.step(val_acc, model):
                break

        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
            " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format(
                epoch,
                np.mean(dur),
                loss.item(),
                train_acc,
                val_acc,
                n_edges / np.mean(dur) / 1000,
            )
        )

    print()
    if args.early_stop:
        model.load_state_dict(torch.load("es_checkpoint.pt"))
    acc = evaluate(g, model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GAT")
    register_data_args(parser)
    parser.add_argument(
        "--gpu",
        type=int,
        default=-1,
        help="which GPU to use. Set -1 to use CPU.",
    )
    parser.add_argument(
        "--epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        default=8,
        help="number of hidden attention heads",
    )
    parser.add_argument(
        "--num-out-heads",
        type=int,
        default=1,
        help="number of output attention heads",
    )
    parser.add_argument(
        "--num-layers", type=int, default=1, help="number of hidden layers"
    )
    parser.add_argument(
        "--num-hidden", type=int, default=8, help="number of hidden units"
    )
    parser.add_argument(
        "--residual",
        action="store_true",
        default=False,
        help="use residual connection",
    )
    parser.add_argument(
        "--in-drop", type=float, default=0.7, help="input feature dropout"
    )
    parser.add_argument(
        "--attn-drop", type=float, default=0.7, help="attention dropout"
    )
    parser.add_argument("--lr", type=float, default=0.005, help="learning rate")
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="weight decay"
    )
    parser.add_argument(
        "--negative-slope",
        type=float,
        default=0.2,
        help="the negative slope of leaky relu",
    )
    parser.add_argument(
        "--early-stop",
        action="store_true",
        default=False,
        help="indicates whether to use early stop or not",
    )
    parser.add_argument(
        "--fastmode",
        action="store_true",
        default=False,
        help="skip re-evaluate the validation set",
    )
    args = parser.parse_args()
    print(args)
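    # Example run (the file name train.py is an assumption; --dataset comes
    # from register_data_args, the remaining flags from the parser above):
    #
    #   python train.py --dataset cora --num-heads 8 --num-out-heads 1 --early-stop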
...
@@ -2,47 +2,55 @@
The script loads the full graph to the training device.
"""

import argparse
import logging
import os
import random
import string
import time

import numpy as np
import torch as th
import torch.nn as nn
from data import MovieLens
from model import BiDecoder, GCMCLayer
from utils import (MetricLogger, get_activation, get_optimizer, torch_net_info,
                   torch_total_param_num)


class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self._act = get_activation(args.model_activation)
        self.encoder = GCMCLayer(
            args.rating_vals,
            args.src_in_units,
            args.dst_in_units,
            args.gcn_agg_units,
            args.gcn_out_units,
            args.gcn_dropout,
            args.gcn_agg_accum,
            agg_act=self._act,
            share_user_item_param=args.share_param,
            device=args.device,
        )
        self.decoder = BiDecoder(
            in_units=args.gcn_out_units,
            num_classes=len(args.rating_vals),
            num_basis=args.gen_r_num_basis_func,
        )

    def forward(self, enc_graph, dec_graph, ufeat, ifeat):
        user_out, movie_out = self.encoder(enc_graph, ufeat, ifeat)
        pred_ratings = self.decoder(dec_graph, user_out, movie_out)
        return pred_ratings


def evaluate(args, net, dataset, segment="valid"):
    possible_rating_values = dataset.possible_rating_values
    nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(
        args.device
    )

    if segment == "valid":
        rating_values = dataset.valid_truths

@@ -58,18 +66,27 @@ def evaluate(args, net, dataset, segment='valid'):
    # Evaluate RMSE
    net.eval()
    with th.no_grad():
        pred_ratings = net(
            enc_graph, dec_graph, dataset.user_feature, dataset.movie_feature
        )
    real_pred_ratings = (
        th.softmax(pred_ratings, dim=1) * nd_possible_rating_values.view(1, -1)
    ).sum(dim=1)
    rmse = ((real_pred_ratings - rating_values) ** 2.0).mean().item()
    rmse = np.sqrt(rmse)
    return rmse


def train(args):
    print(args)
    dataset = MovieLens(
        args.data_name,
        args.device,
        use_one_hot_fea=args.use_one_hot_fea,
        symm=args.gcn_agg_norm_symm,
        test_ratio=args.data_test_ratio,
        valid_ratio=args.data_valid_ratio,
    )
    print("Loading data finished ...\n")

    args.src_in_units = dataset.user_feature_shape[1]

@@ -79,10 +96,14 @@ def train(args):
    ### build the net
    net = Net(args=args)
    net = net.to(args.device)
    nd_possible_rating_values = th.FloatTensor(
        dataset.possible_rating_values
    ).to(args.device)
    rating_loss_net = nn.CrossEntropyLoss()
    learning_rate = args.train_lr
    optimizer = get_optimizer(args.train_optimizer)(
        net.parameters(), lr=learning_rate
    )
    print("Loading network finished ...\n")

    ### prepare training data

@@ -90,12 +111,21 @@ def train(args):
    train_gt_ratings = dataset.train_truths

    ### prepare the logger
    train_loss_logger = MetricLogger(
        ["iter", "loss", "rmse"],
        ["%d", "%.4f", "%.4f"],
        os.path.join(args.save_dir, "train_loss%d.csv" % args.save_id),
    )
    valid_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "valid_loss%d.csv" % args.save_id),
    )
    test_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "test_loss%d.csv" % args.save_id),
    )

    ### declare the loss information
    best_valid_rmse = np.inf

@@ -118,8 +148,12 @@ def train(args):
        if iter_idx > 3:
            t0 = time.time()
        net.train()
        pred_ratings = net(
            dataset.train_enc_graph,
            dataset.train_dec_graph,
            dataset.user_feature,
            dataset.movie_feature,
        )
        loss = rating_loss_net(pred_ratings, train_gt_labels).mean()
        count_loss += loss.item()
        optimizer.zero_grad()

@@ -132,97 +166,148 @@ def train(args):
        if iter_idx == 1:
            print("Total #Param of net: %d" % (torch_total_param_num(net)))
            print(
                torch_net_info(
                    net,
                    save_path=os.path.join(
                        args.save_dir, "net%d.txt" % args.save_id
                    ),
                )
            )

        real_pred_ratings = (
            th.softmax(pred_ratings, dim=1)
            * nd_possible_rating_values.view(1, -1)
        ).sum(dim=1)
        rmse = ((real_pred_ratings - train_gt_ratings) ** 2).sum()
        count_rmse += rmse.item()
        count_num += pred_ratings.shape[0]

        if iter_idx % args.train_log_interval == 0:
            train_loss_logger.log(
                iter=iter_idx,
                loss=count_loss / (iter_idx + 1),
                rmse=count_rmse / count_num,
            )
            logging_str = (
                "Iter={}, loss={:.4f}, rmse={:.4f}, time={:.4f}".format(
                    iter_idx,
                    count_loss / iter_idx,
                    count_rmse / count_num,
                    np.average(dur),
                )
            )
            count_rmse = 0
            count_num = 0

        if iter_idx % args.train_valid_interval == 0:
            valid_rmse = evaluate(
                args=args, net=net, dataset=dataset, segment="valid"
            )
            valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
            logging_str += ",\tVal RMSE={:.4f}".format(valid_rmse)

            if valid_rmse < best_valid_rmse:
                best_valid_rmse = valid_rmse
                no_better_valid = 0
                best_iter = iter_idx
                test_rmse = evaluate(
                    args=args, net=net, dataset=dataset, segment="test"
                )
                best_test_rmse = test_rmse
                test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                logging_str += ", Test RMSE={:.4f}".format(test_rmse)
            else:
                no_better_valid += 1
                if (
                    no_better_valid > args.train_early_stopping_patience
                    and learning_rate <= args.train_min_lr
                ):
                    logging.info(
                        "Early stopping threshold reached. Stop training."
                    )
                    break
                if no_better_valid > args.train_decay_patience:
                    new_lr = max(
                        learning_rate * args.train_lr_decay_factor,
                        args.train_min_lr,
                    )
                    if new_lr < learning_rate:
                        learning_rate = new_lr
                        logging.info("\tChange the LR to %g" % new_lr)
                        for p in optimizer.param_groups:
                            p["lr"] = learning_rate
                        no_better_valid = 0

        if iter_idx % args.train_log_interval == 0:
            print(logging_str)

    print(
        "Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}".format(
            best_iter, best_valid_rmse, best_test_rmse
        )
    )
    train_loss_logger.close()
    valid_loss_logger.close()
    test_loss_logger.close()


def config():
    parser = argparse.ArgumentParser(description="GCMC")
    parser.add_argument("--seed", default=123, type=int)
    parser.add_argument(
        "--device",
        default="0",
        type=int,
        help="Running device. E.g `--device 0`, if using cpu, set `--device -1`",
    )
    parser.add_argument("--save_dir", type=str, help="The saving directory")
    parser.add_argument("--save_id", type=int, help="The saving log id")
    parser.add_argument("--silent", action="store_true")
    parser.add_argument(
        "--data_name",
        default="ml-1m",
        type=str,
        help="The dataset name: ml-100k, ml-1m, ml-10m",
    )
    parser.add_argument(
        "--data_test_ratio", type=float, default=0.1
    )  ## for ml-100k the test ratio is 0.2
    parser.add_argument("--data_valid_ratio", type=float, default=0.1)
    parser.add_argument("--use_one_hot_fea", action="store_true", default=False)
    parser.add_argument("--model_activation", type=str, default="leaky")
    parser.add_argument("--gcn_dropout", type=float, default=0.7)
    parser.add_argument("--gcn_agg_norm_symm", type=bool, default=True)
    parser.add_argument("--gcn_agg_units", type=int, default=500)
    parser.add_argument("--gcn_agg_accum", type=str, default="sum")
    parser.add_argument("--gcn_out_units", type=int, default=75)
    parser.add_argument("--gen_r_num_basis_func", type=int, default=2)
    parser.add_argument("--train_max_iter", type=int, default=2000)
    parser.add_argument("--train_log_interval", type=int, default=1)
    parser.add_argument("--train_valid_interval", type=int, default=1)
    parser.add_argument("--train_optimizer", type=str, default="adam")
    parser.add_argument("--train_grad_clip", type=float, default=1.0)
    parser.add_argument("--train_lr", type=float, default=0.01)
    parser.add_argument("--train_min_lr", type=float, default=0.001)
    parser.add_argument("--train_lr_decay_factor", type=float, default=0.5)
    parser.add_argument("--train_decay_patience", type=int, default=50)
    parser.add_argument(
        "--train_early_stopping_patience", type=int, default=100
    )
    parser.add_argument("--share_param", default=False, action="store_true")

    args = parser.parse_args()
    args.device = (
        th.device(args.device) if args.device >= 0 else th.device("cpu")
    )

    ### configure save_dir to save all the info
    if args.save_dir is None:
        args.save_dir = (
            args.data_name
            + "_"
            + "".join(
                random.choices(string.ascii_uppercase + string.digits, k=2)
            )
        )
    if args.save_id is None:
        args.save_id = np.random.randint(20)
    args.save_dir = os.path.join("log", args.save_dir)

@@ -232,7 +317,7 @@ def config():
    return args


if __name__ == "__main__":
    args = config()

    np.random.seed(args.seed)
    th.manual_seed(args.seed)
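    # Example run (the file name train.py is an assumption; flags come from
    # config() above, and the comment there suggests --data_test_ratio 0.2
    # for ml-100k):
    #
    #   python train.py --data_name ml-100k --device 0 --data_test_ratio 0.2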
...
import csv import csv
import re import re
import torch as th from collections import OrderedDict
import numpy as np import numpy as np
import torch as th
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from collections import OrderedDict
class MetricLogger(object):
    def __init__(self, attr_names, parse_formats, save_path):
        self._attr_format_dict = OrderedDict(zip(attr_names, parse_formats))
        self._file = open(save_path, "w")
        self._csv = csv.writer(self._file)
        self._csv.writerow(attr_names)
        self._file.flush()

    def log(self, **kwargs):
        self._csv.writerow(
            [
                parse_format % kwargs[attr_name]
                for attr_name, parse_format in self._attr_format_dict.items()
            ]
        )
        self._file.flush()
    def close(self):
...@@ -28,13 +34,15 @@ def torch_total_param_num(net):
def torch_net_info(net, save_path=None):
    info_str = (
        "Total Param Number: {}\n".format(torch_total_param_num(net))
        + "Params:\n"
    )
    for k, v in net.named_parameters():
        info_str += "\t{}: {}, {}\n".format(k, v.shape, np.prod(v.shape))
    info_str += str(net)
    if save_path is not None:
        with open(save_path, "w") as f:
            f.write(info_str)
    return info_str
...@@ -53,15 +61,15 @@ def get_activation(act):
    if act is None:
        return lambda x: x
    if isinstance(act, str):
        if act == "leaky":
            return nn.LeakyReLU(0.1)
        elif act == "relu":
            return nn.ReLU()
        elif act == "tanh":
            return nn.Tanh()
        elif act == "sigmoid":
            return nn.Sigmoid()
        elif act == "softsign":
            return nn.Softsign()
        else:
            raise NotImplementedError
...@@ -70,13 +78,13 @@ def get_activation(act):
def get_optimizer(opt):
    if opt == "sgd":
        return optim.SGD
    elif opt == "adam":
        return optim.Adam
    else:
        raise NotImplementedError


def to_etype_name(rating):
    return str(rating).replace(".", "_")
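For orientation, a minimal usage sketch of the utilities above; the attribute names, formats, and file path are illustrative, not taken from this commit:

# Hypothetical usage of the helpers above; names and values are illustrative.
logger = MetricLogger(
    attr_names=["iter", "rmse"],
    parse_formats=["%d", "%.4f"],
    save_path="valid_metrics.csv",
)
logger.log(iter=1, rmse=1.1234)  # writes the row "1,1.1234" and flushes
act = get_activation("leaky")  # nn.LeakyReLU(0.1)
opt_cls = get_optimizer("adam")  # optim.Adam (the class, not an instance)
print(to_etype_name(0.5))  # "0_5" -- safe to use as a DGL edge-type name
logger.close()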
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.nn as dglnn
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset


class GCN(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # two-layer GCN
        self.layers.append(
            dglnn.GraphConv(in_size, hid_size, activation=F.relu)
        )
        self.layers.append(dglnn.GraphConv(hid_size, out_size))
        self.dropout = nn.Dropout(0.5)
...@@ -23,7 +28,8 @@ class GCN(nn.Module):
                h = self.dropout(h)
            h = layer(g, h)
        return h


def evaluate(g, features, labels, mask, model):
    model.eval()
    with torch.no_grad():
...@@ -51,49 +57,59 @@ def train(g, features, labels, masks, model):
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    args = parser.parse_args()
    print("Training with DGL built-in GraphConv module.")

    # load and preprocess dataset
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]

    # normalization
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5).to(device)
    norm[torch.isinf(norm)] = 0
    g.ndata["norm"] = norm.unsqueeze(1)

    # create GCN model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = GCN(in_size, 16, out_size).to(device)

    # model training
    print("Training...")
    train(g, features, labels, masks, model)

    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))
import torch as th
import torch.nn as nn
from torch.nn import LSTM

from dgl.nn import GATConv


class GeniePathConv(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, num_heads=1, residual=False):
        super(GeniePathConv, self).__init__()
        self.breadth_func = GATConv(
            in_dim, hid_dim, num_heads=num_heads, residual=residual
        )
        self.depth_func = LSTM(hid_dim, out_dim)

    def forward(self, graph, x, h, c):
...@@ -20,14 +23,30 @@ class GeniePathConv(nn.Module):
class GeniePath(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        hid_dim=16,
        num_layers=2,
        num_heads=1,
        residual=False,
    ):
        super(GeniePath, self).__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
        self.linear2 = nn.Linear(hid_dim, out_dim)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(
                GeniePathConv(
                    hid_dim,
                    hid_dim,
                    hid_dim,
                    num_heads=num_heads,
                    residual=residual,
                )
            )

    def forward(self, graph, x):
        h = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)
...@@ -42,7 +61,15 @@ class GeniePath(nn.Module):
class GeniePathLazy(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        hid_dim=16,
        num_layers=2,
        num_heads=1,
        residual=False,
    ):
        super(GeniePathLazy, self).__init__()
        self.hid_dim = hid_dim
        self.linear1 = nn.Linear(in_dim, hid_dim)
...@@ -50,8 +77,12 @@ class GeniePathLazy(nn.Module):
        self.breaths = nn.ModuleList()
        self.depths = nn.ModuleList()
        for i in range(num_layers):
            self.breaths.append(
                GATConv(
                    hid_dim, hid_dim, num_heads=num_heads, residual=residual
                )
            )
            self.depths.append(LSTM(hid_dim * 2, hid_dim))

    def forward(self, graph, x):
        h = th.zeros(1, x.shape[0], self.hid_dim).to(x.device)
...
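A quick, hypothetical smoke test for the modules above; the toy graph and dimensions are invented here, and it assumes the (collapsed) forward methods return one logit vector per node, as the training scripts below expect:

# Illustrative smoke test; graph and sizes are made up for this sketch.
import dgl
import torch as th

g = dgl.graph(([0, 1, 2], [1, 2, 0]))  # tiny 3-node directed cycle
x = th.randn(3, 8)  # 3 nodes, 8 input features
model = GeniePath(in_dim=8, out_dim=4)  # defaults: hid_dim=16, num_layers=2
out = model(g, x)  # expected shape: (3, 4)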
...@@ -3,27 +3,27 @@ import argparse
import numpy as np
import torch as th
import torch.optim as optim
from model import GeniePath, GeniePathLazy
from sklearn.metrics import f1_score

from dgl.data import PPIDataset
from dgl.dataloading import GraphDataLoader


def evaluate(model, loss_fn, dataloader, device="cpu"):
    loss = 0
    f1 = 0
    num_blocks = 0
    for subgraph in dataloader:
        subgraph = subgraph.to(device)
        label = subgraph.ndata["label"].to(device)
        feat = subgraph.ndata["feat"]
        logits = model(subgraph, feat)

        # compute loss
        loss += loss_fn(logits, label).item()
        predict = np.where(logits.data.cpu().numpy() >= 0.0, 1, 0)
        f1 += f1_score(label.cpu(), predict, average="micro")
        num_blocks += 1

    return f1 / num_blocks, loss / num_blocks
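One detail worth noting in evaluate above: comparing raw logits against 0.0 is equivalent to thresholding sigmoid probabilities at 0.5, since sigmoid(z) >= 0.5 exactly when z >= 0. A small illustration:

# Illustration only: a logit threshold of 0.0 equals a probability
# threshold of 0.5 under the sigmoid, so no sigmoid call is needed.
import numpy as np

logits = np.array([[-1.2, 0.3], [0.8, -0.1]])
predict = np.where(logits >= 0.0, 1, 0)  # -> [[0, 1], [1, 0]]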
...@@ -32,40 +32,48 @@ def evaluate(model, loss_fn, dataloader, device='cpu'):
def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load dataset
    train_dataset = PPIDataset(mode="train")
    valid_dataset = PPIDataset(mode="valid")
    test_dataset = PPIDataset(mode="test")

    train_dataloader = GraphDataLoader(
        train_dataset, batch_size=args.batch_size
    )
    valid_dataloader = GraphDataLoader(
        valid_dataset, batch_size=args.batch_size
    )
    test_dataloader = GraphDataLoader(test_dataset, batch_size=args.batch_size)

    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    num_classes = train_dataset.num_labels

    # Extract node features
    graph = train_dataset[0]
    feat = graph.ndata["feat"]

    # Step 2: Create model =================================================================== #
    if args.lazy:
        model = GeniePathLazy(
            in_dim=feat.shape[-1],
            out_dim=num_classes,
            hid_dim=args.hid_dim,
            num_layers=args.num_layers,
            num_heads=args.num_heads,
            residual=args.residual,
        )
    else:
        model = GeniePath(
            in_dim=feat.shape[-1],
            out_dim=num_classes,
            hid_dim=args.hid_dim,
            num_layers=args.num_layers,
            num_heads=args.num_heads,
            residual=args.residual,
        )
    model = model.to(device)
...@@ -81,15 +89,15 @@ def main(args):
        num_blocks = 0
        for subgraph in train_dataloader:
            subgraph = subgraph.to(device)
            label = subgraph.ndata["label"]
            feat = subgraph.ndata["feat"]
            logits = model(subgraph, feat)

            # compute loss
            batch_loss = loss_fn(logits, label)
            tr_loss += batch_loss.item()
            tr_predict = np.where(logits.data.cpu().numpy() >= 0.0, 1, 0)
            tr_f1 += f1_score(label.cpu(), tr_predict, average="micro")
            num_blocks += 1

            # backward
...@@ -101,28 +109,64 @@ def main(args):
        model.eval()
        val_f1, val_loss = evaluate(model, loss_fn, valid_dataloader, device)
        print(
            "In epoch {}, Train F1: {:.4f} | Train Loss: {:.4f}; Valid F1: {:.4f} | Valid loss: {:.4f}".format(
                epoch,
                tr_f1 / num_blocks,
                tr_loss / num_blocks,
                val_f1,
                val_loss,
            )
        )

    # Test after all epoch
    model.eval()
    test_f1, test_loss = evaluate(model, loss_fn, test_dataloader, device)
    print("Test F1: {:.4f} | Test loss: {:.4f}".format(test_f1, test_loss))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GeniePath")
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU."
    )
    parser.add_argument(
        "--hid_dim", type=int, default=256, help="Hidden layer dimension"
    )
    parser.add_argument(
        "--num_layers", type=int, default=3, help="Number of GeniePath layers"
    )
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=1000,
        help="The max number of epochs. Default: 1000",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0004,
        help="Learning rate. Default: 0.0004",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=1,
        help="Number of heads in breadth function. Default: 1",
    )
    parser.add_argument(
        "--residual", type=bool, default=False, help="Residual in GAT or not"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2,
        help="Batch size of graph dataloader",
    )
    parser.add_argument(
        "--lazy", type=bool, default=False, help="Variant GeniePath-Lazy"
    )

    args = parser.parse_args()
    print(args)
...
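A caveat the formatter cannot fix: argparse's type=bool converts any non-empty string to True (bool("False") is True), so --residual False would still enable the flag. A more robust pattern, were these scripts revised, is a store_true action; this is a hypothetical change, not part of the commit:

# Hypothetical revision (not in this commit): store_true avoids the
# type=bool pitfall where any non-empty string parses as True.
parser.add_argument(
    "--residual", default=False, action="store_true", help="Residual in GAT"
)
parser.add_argument(
    "--lazy", default=False, action="store_true", help="Variant GeniePath-Lazy"
)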
...@@ -2,10 +2,10 @@ import argparse
import torch as th
import torch.optim as optim
from model import GeniePath, GeniePathLazy
from sklearn.metrics import accuracy_score

from dgl.data import PubmedGraphDataset


def main(args):
...@@ -16,22 +16,22 @@ def main(args):
    # check cuda
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    num_classes = dataset.num_classes

    # retrieve label of ground truth
    label = graph.ndata["label"].to(device)

    # Extract node features
    feat = graph.ndata["feat"].to(device)

    # retrieve masks for train/validation/test
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    train_idx = th.nonzero(train_mask, as_tuple=False).squeeze(1).to(device)
    val_idx = th.nonzero(val_mask, as_tuple=False).squeeze(1).to(device)
...@@ -41,19 +41,23 @@ def main(args):
    # Step 2: Create model =================================================================== #
    if args.lazy:
        model = GeniePathLazy(
            in_dim=feat.shape[-1],
            out_dim=num_classes,
            hid_dim=args.hid_dim,
            num_layers=args.num_layers,
            num_heads=args.num_heads,
            residual=args.residual,
        )
    else:
        model = GeniePath(
            in_dim=feat.shape[-1],
            out_dim=num_classes,
            hid_dim=args.hid_dim,
            num_layers=args.num_layers,
            num_heads=args.num_heads,
            residual=args.residual,
        )
    model = model.to(device)
...@@ -69,11 +73,15 @@ def main(args):
        # compute loss
        tr_loss = loss_fn(logits[train_idx], label[train_idx])
        tr_acc = accuracy_score(
            label[train_idx].cpu(), logits[train_idx].argmax(dim=1).cpu()
        )

        # validation
        valid_loss = loss_fn(logits[val_idx], label[val_idx])
        valid_acc = accuracy_score(
            label[val_idx].cpu(), logits[val_idx].argmax(dim=1).cpu()
        )

        # backward
        optimizer.zero_grad()
...@@ -81,8 +89,11 @@ def main(args):
        optimizer.step()

        # Print out performance
        print(
            "In epoch {}, Train ACC: {:.4f} | Train Loss: {:.4f}; Valid ACC: {:.4f} | Valid loss: {:.4f}".format(
                epoch, tr_acc, tr_loss.item(), valid_acc, valid_loss.item()
            )
        )

    # Test after all epoch
    model.eval()
...@@ -92,22 +103,52 @@ def main(args):
    # compute loss
    test_loss = loss_fn(logits[test_idx], label[test_idx])
    test_acc = accuracy_score(
        label[test_idx].cpu(), logits[test_idx].argmax(dim=1).cpu()
    )

    print(
        "Test ACC: {:.4f} | Test loss: {:.4f}".format(
            test_acc, test_loss.item()
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GeniePath")
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU."
    )
    parser.add_argument(
        "--hid_dim", type=int, default=16, help="Hidden layer dimension"
    )
    parser.add_argument(
        "--num_layers", type=int, default=2, help="Number of GeniePath layers"
    )
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=300,
        help="The max number of epochs. Default: 300",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0004,
        help="Learning rate. Default: 0.0004",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=1,
        help="Number of heads in breadth function. Default: 1",
    )
    parser.add_argument(
        "--residual", type=bool, default=False, help="Residual in GAT or not"
    )
    parser.add_argument(
        "--lazy", type=bool, default=False, help="Variant GeniePath-Lazy"
    )

    args = parser.parse_args()
    th.manual_seed(16)
...