import argparse

import numpy as np
import torch as th
import torch.optim as optim
from model import GeniePath, GeniePathLazy
from sklearn.metrics import f1_score

from dgl.data import PPIDataset
from dgl.dataloading import GraphDataLoader

_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "1"})
# The empty string is kept as "false" because the previous ``type=bool``
# converter also mapped it to False.
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "0", ""})


def _str2bool(value):
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string (including ``"False"``). This converter inspects
    the text instead.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value)
    )


def _micro_f1(logits, label):
    """Return the micro-averaged F1 of thresholded ``logits`` against ``label``.

    A logit >= 0 is predicted as class 1 (equivalent to a sigmoid output
    >= 0.5), anything below as class 0.
    """
    predict = np.where(logits.detach().cpu().numpy() >= 0.0, 1, 0)
    return f1_score(label.cpu().numpy(), predict, average="micro")


def _mean(total, count):
    """Return ``total / count``, or NaN when nothing was accumulated."""
    return total / count if count else float("nan")


def evaluate(model, loss_fn, dataloader, device="cpu"):
    """Compute the average micro-F1 and average loss over ``dataloader``.

    The function does not switch the model to evaluation mode; callers are
    expected to call ``model.eval()`` beforehand.

    Args:
        model: callable invoked as ``model(subgraph, feat)`` returning logits.
        loss_fn: callable invoked as ``loss_fn(logits, label)`` returning a
            tensor that supports ``.item()``.
        dataloader: iterable of graphs carrying ``ndata["feat"]`` and
            ``ndata["label"]``.
        device: device each graph (and its label) is moved to.

    Returns:
        ``(f1, loss)``, each averaged over the number of batches. Both are
        NaN when ``dataloader`` yields no batches.
    """
    total_loss = 0.0
    total_f1 = 0.0
    num_blocks = 0
    # Gradients are never used here, so skip building the autograd graph.
    with th.no_grad():
        for subgraph in dataloader:
            subgraph = subgraph.to(device)
            label = subgraph.ndata["label"].to(device)
            feat = subgraph.ndata["feat"]
            logits = model(subgraph, feat)

            total_loss += loss_fn(logits, label).item()
            total_f1 += _micro_f1(logits, label)
            num_blocks += 1

    return _mean(total_f1, num_blocks), _mean(total_loss, num_blocks)


def main(args):
    """Train a GeniePath model on the PPI dataset and report test metrics.

    Args:
        args: namespace providing ``batch_size``, ``gpu``, ``lazy``,
            ``hid_dim``, ``num_layers``, ``num_heads``, ``residual``, ``lr``
            and ``max_epoch``.
    """
    # Step 1: Prepare graph data and retrieve train/validation/test splits === #
    train_dataset = PPIDataset(mode="train")
    valid_dataset = PPIDataset(mode="valid")
    test_dataset = PPIDataset(mode="test")

    train_dataloader = GraphDataLoader(
        train_dataset, batch_size=args.batch_size
    )
    valid_dataloader = GraphDataLoader(
        valid_dataset, batch_size=args.batch_size
    )
    test_dataloader = GraphDataLoader(test_dataset, batch_size=args.batch_size)

    # Use the requested GPU only when CUDA is actually available.
    if args.gpu >= 0 and th.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    num_classes = train_dataset.num_labels

    # The first training graph is only used to read the input feature width.
    feat = train_dataset[0].ndata["feat"]

    # Step 2: Create model =================================================== #
    # Both variants take the same constructor arguments.
    model_cls = GeniePathLazy if args.lazy else GeniePath
    model = model_cls(
        in_dim=feat.shape[-1],
        out_dim=num_classes,
        hid_dim=args.hid_dim,
        num_layers=args.num_layers,
        num_heads=args.num_heads,
        residual=args.residual,
    )
    model = model.to(device)

    # Step 3: Create training components ===================================== #
    loss_fn = th.nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Step 4: training epochs ================================================ #
    for epoch in range(args.max_epoch):
        model.train()
        tr_loss = 0.0
        tr_f1 = 0.0
        num_blocks = 0
        for subgraph in train_dataloader:
            subgraph = subgraph.to(device)
            label = subgraph.ndata["label"]
            feat = subgraph.ndata["feat"]
            logits = model(subgraph, feat)

            # compute loss and running training metrics
            batch_loss = loss_fn(logits, label)
            tr_loss += batch_loss.item()
            tr_f1 += _micro_f1(logits, label)
            num_blocks += 1

            # backward
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

        # validation
        model.eval()
        val_f1, val_loss = evaluate(model, loss_fn, valid_dataloader, device)

        print(
            "In epoch {}, Train F1: {:.4f} | Train Loss: {:.4f}; Valid F1: {:.4f} | Valid loss: {:.4f}".format(
                epoch,
                _mean(tr_f1, num_blocks),
                _mean(tr_loss, num_blocks),
                val_f1,
                val_loss,
            )
        )

    # Test after all epochs
    model.eval()
    test_f1, test_loss = evaluate(model, loss_fn, test_dataloader, device)
    print("Test F1: {:.4f} | Test loss: {:.4f}".format(test_f1, test_loss))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GeniePath")
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU Index. Default: -1, using CPU."
    )
    parser.add_argument(
        "--hid_dim", type=int, default=256, help="Hidden layer dimension"
    )
    parser.add_argument(
        "--num_layers", type=int, default=3, help="Number of GeniePath layers"
    )
    parser.add_argument(
        "--max_epoch",
        type=int,
        default=1000,
        help="The max number of epochs. Default: 1000",
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.0004,
        help="Learning rate. Default: 0.0004",
    )
    parser.add_argument(
        "--num_heads",
        type=int,
        default=1,
        help="Number of head in breadth function. Default: 1",
    )
    # Boolean flags: ``type=bool`` would treat "False" as True, so parse the
    # text explicitly. ``nargs="?"`` with ``const=True`` also allows the bare
    # form (``--residual`` / ``--lazy``) in addition to ``--lazy True``.
    parser.add_argument(
        "--residual",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Residual in GAT or not",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=2,
        help="Batch size of graph dataloader",
    )
    parser.add_argument(
        "--lazy",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Variant GeniePath-Lazy",
    )

    args = parser.parse_args()
    print(args)
    th.manual_seed(16)
    main(args)