import itertools
import time
from collections import defaultdict as ddict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from catboost import CatBoostClassifier, CatBoostRegressor, Pool, sum_models
from sklearn import preprocessing
from sklearn.metrics import r2_score
from tqdm import tqdm


class BGNNPredictor:
    """
    Description
    -----------
    Boost GNN predictor for semi-supervised node classification or regression
    problems. Publication: https://arxiv.org/abs/2101.08543

    Parameters
    ----------
    gnn_model : nn.Module
        DGL implementation of a GNN model.
    task : str, optional
        Regression or classification task.
    loss_fn : callable, optional
        Function that takes torch tensors, pred and true, and returns a scalar.
    trees_per_epoch : int, optional
        Number of GBDT trees to build each epoch.
    backprop_per_epoch : int, optional
        Number of backpropagation steps to make each epoch.
    lr : float, optional
        Learning rate of the gradient descent optimizer.
    append_gbdt_pred : bool, optional
        Whether to append GBDT predictions to the original input node features
        or to replace them.
    train_input_features : bool, optional
        Whether to train the original input node features.
    gbdt_depth : int, optional
        Depth of each tree in the GBDT model.
    gbdt_lr : float, optional
        Learning rate of the GBDT model.
    gbdt_alpha : float, optional
        Weight used to combine previous and new GBDT trees.
    random_seed : int, optional
        Random seed for the GNN and GBDT models.

    Examples
    --------
    gnn_model = GAT(10, 20, num_heads=5)
    bgnn = BGNNPredictor(gnn_model)
    metrics = bgnn.fit(graph, X, y, train_mask, val_mask, test_mask, cat_features)
    """

    def __init__(
        self,
        gnn_model,
        task="regression",
        loss_fn=None,
        trees_per_epoch=10,
        backprop_per_epoch=10,
        lr=0.01,
        append_gbdt_pred=True,
        train_input_features=False,
        gbdt_depth=6,
        gbdt_lr=0.1,
        gbdt_alpha=1,
        random_seed=0,
    ):
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu"
        )
        self.model = gnn_model.to(self.device)
        self.task = task
        self.loss_fn = loss_fn
        self.trees_per_epoch = trees_per_epoch
        self.backprop_per_epoch = backprop_per_epoch
        self.lr = lr
        self.append_gbdt_pred = append_gbdt_pred
        self.train_input_features = train_input_features
        self.gbdt_depth = gbdt_depth
        self.gbdt_lr = gbdt_lr
        self.gbdt_alpha = gbdt_alpha
        self.random_seed = random_seed

        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    def init_gbdt_model(self, num_epochs, epoch):
        if self.task == "regression":
            catboost_model_obj = CatBoostRegressor
            catboost_loss_fn = "RMSE"
        else:
            if epoch == 0:  # we predict multiclass probs at the first epoch
                catboost_model_obj = CatBoostClassifier
                catboost_loss_fn = "MultiClass"
            else:  # we predict the gradients for each class at epochs > 0
                catboost_model_obj = CatBoostRegressor
                catboost_loss_fn = "MultiRMSE"

        return catboost_model_obj(
            iterations=num_epochs,
            depth=self.gbdt_depth,
            learning_rate=self.gbdt_lr,
            loss_function=catboost_loss_fn,
            random_seed=self.random_seed,
            nan_mode="Min",
        )

    def fit_gbdt(self, pool, trees_per_epoch, epoch):
        gbdt_model = self.init_gbdt_model(trees_per_epoch, epoch)
        gbdt_model.fit(pool, verbose=False)
        return gbdt_model

    def append_gbdt_model(self, new_gbdt_model, weights):
        if self.gbdt_model is None:
            return new_gbdt_model
        return sum_models([self.gbdt_model, new_gbdt_model], weights=weights)

    def train_gbdt(
        self,
        gbdt_X_train,
        gbdt_y_train,
        cat_features,
        epoch,
        gbdt_trees_per_epoch,
        gbdt_alpha,
    ):
        pool = Pool(gbdt_X_train, gbdt_y_train, cat_features=cat_features)
        epoch_gbdt_model = self.fit_gbdt(pool, gbdt_trees_per_epoch, epoch)

        if epoch == 0 and self.task == "classification":
            self.base_gbdt = epoch_gbdt_model
        else:
            self.gbdt_model = self.append_gbdt_model(
                epoch_gbdt_model, weights=[1, gbdt_alpha]
            )
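    # A hedged sketch of the ensemble arithmetic in train_gbdt above:
    # catboost.sum_models with weights=[1, gbdt_alpha] yields a model whose raw
    # prediction is the weighted sum of its parts (names are illustrative):
    #
    #   combined = sum_models([old_model, new_model], weights=[1.0, 0.5])
    #   # combined.predict(x) ~= old_model.predict(x) + 0.5 * new_model.predict(x)
    #
    # so each epoch stacks gbdt_alpha-scaled trees onto the running ensemble.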
    def update_node_features(self, node_features, X, original_X):
        # Get predictions from the GBDT model.
        if self.task == "regression":
            predictions = np.expand_dims(
                self.gbdt_model.predict(original_X), axis=1
            )
        else:
            predictions = self.base_gbdt.predict_proba(original_X)
            if self.gbdt_model is not None:
                predictions_after_one = self.gbdt_model.predict(original_X)
                predictions += predictions_after_one

        # Update node features with the predictions.
        if self.append_gbdt_pred:
            if self.train_input_features:
                predictions = np.append(
                    node_features.detach().cpu().data[:, : -self.out_dim],
                    predictions,
                    axis=1,
                )  # replace old predictions with new predictions
            else:
                predictions = np.append(
                    X, predictions, axis=1
                )  # append new predictions to the original features

        predictions = torch.from_numpy(predictions).to(self.device)
        node_features.data = predictions.float().data

    def update_gbdt_targets(
        self, node_features, node_features_before, train_mask
    ):
        return (
            (node_features - node_features_before)
            .detach()
            .cpu()
            .numpy()[train_mask, -self.out_dim :]
        )

    def init_node_features(self, X):
        node_features = torch.empty(
            X.shape[0], self.in_dim, requires_grad=True, device=self.device
        )
        if self.append_gbdt_pred:
            node_features.data[:, : -self.out_dim] = torch.from_numpy(
                X.to_numpy(copy=True)
            )
        return node_features

    def init_optimizer(
        self, node_features, optimize_node_features, learning_rate
    ):
        params = [self.model.parameters()]
        if optimize_node_features:
            params.append([node_features])
        optimizer = torch.optim.Adam(
            itertools.chain(*params), lr=learning_rate
        )
        return optimizer

    def train_model(self, model_in, target_labels, train_mask, optimizer):
        y = target_labels[train_mask]

        self.model.train()
        logits = self.model(*model_in).squeeze()
        pred = logits[train_mask]

        if self.loss_fn is not None:
            loss = self.loss_fn(pred, y)
        else:
            if self.task == "regression":
                loss = torch.sqrt(F.mse_loss(pred, y))
            elif self.task == "classification":
                loss = F.cross_entropy(pred, y.long())
            else:
                raise NotImplementedError(
                    "Unknown task. Supported tasks: classification, regression."
                )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss
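    # Note on train_model above: the optimizer from init_optimizer updates both
    # the GNN weights and the node features themselves (created with
    # requires_grad=True in init_node_features). fit() then measures how far
    # the last out_dim columns moved during backprop, roughly
    #
    #   new_targets = (features_after - features_before)[train_mask, -out_dim:]
    #
    # and feeds that residual to the next GBDT round via update_gbdt_targets,
    # so the GBDT learns to reproduce the gradient step taken on its own output.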
    def evaluate_model(self, logits, target_labels, mask):
        metrics = {}
        y = target_labels[mask]
        with torch.no_grad():
            pred = logits[mask]
            if self.task == "regression":
                metrics["loss"] = torch.sqrt(
                    F.mse_loss(pred, y).squeeze() + 1e-8
                )
                metrics["rmsle"] = torch.sqrt(
                    F.mse_loss(torch.log(pred + 1), torch.log(y + 1)).squeeze()
                    + 1e-8
                )
                metrics["mae"] = F.l1_loss(pred, y)
                metrics["r2"] = torch.Tensor(
                    [r2_score(y.cpu().numpy(), pred.cpu().numpy())]
                )
            elif self.task == "classification":
                metrics["loss"] = F.cross_entropy(pred, y.long())
                metrics["accuracy"] = torch.Tensor(
                    [(y == pred.max(1)[1]).sum().item() / y.shape[0]]
                )
        return metrics

    def train_and_evaluate(
        self,
        model_in,
        target_labels,
        train_mask,
        val_mask,
        test_mask,
        optimizer,
        metrics,
        gnn_passes_per_epoch,
    ):
        loss = None
        for _ in range(gnn_passes_per_epoch):
            loss = self.train_model(
                model_in, target_labels, train_mask, optimizer
            )

        self.model.eval()
        logits = self.model(*model_in).squeeze()
        train_results = self.evaluate_model(logits, target_labels, train_mask)
        val_results = self.evaluate_model(logits, target_labels, val_mask)
        test_results = self.evaluate_model(logits, target_labels, test_mask)
        for metric_name in train_results:
            metrics[metric_name].append(
                (
                    train_results[metric_name].detach().item(),
                    val_results[metric_name].detach().item(),
                    test_results[metric_name].detach().item(),
                )
            )
        return loss

    def update_early_stopping(
        self,
        metrics,
        epoch,
        best_metric,
        best_val_epoch,
        epochs_since_last_best_metric,
        metric_name,
        lower_better=False,
    ):
        train_metric, val_metric, test_metric = metrics[metric_name][-1]
        if (lower_better and val_metric < best_metric[1]) or (
            not lower_better and val_metric > best_metric[1]
        ):
            best_metric = metrics[metric_name][-1]
            best_val_epoch = epoch
            epochs_since_last_best_metric = 0
        else:
            epochs_since_last_best_metric += 1
        return best_metric, best_val_epoch, epochs_since_last_best_metric

    def log_epoch(
        self,
        pbar,
        metrics,
        epoch,
        loss,
        epoch_time,
        logging_epochs,
        metric_name="loss",
    ):
        train_metric, val_metric, test_metric = metrics[metric_name][-1]
        if epoch and epoch % logging_epochs == 0:
            pbar.set_description(
                "Epoch {:05d} | Loss {:.3f} | {} {:.3f}/{:.3f}/{:.3f} | Time {:.4f}".format(
                    epoch,
                    loss,
                    metric_name,
                    train_metric,
                    val_metric,
                    test_metric,
                    epoch_time,
                )
            )
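    # Bookkeeping shape assumed by the helpers above (illustrative values):
    # metrics[metric_name] holds one (train, val, test) tuple per epoch, e.g.
    # metrics["loss"][-1] == (0.41, 0.52, 0.55); early stopping compares the
    # validation entry (index 1) against best_metric[1].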
    def fit(
        self,
        graph,
        X,
        y,
        train_mask,
        val_mask,
        test_mask,
        original_X=None,
        cat_features=None,
        num_epochs=100,
        patience=10,
        logging_epochs=1,
        metric_name="loss",
    ):
        """
        :param graph: dgl.DGLGraph
            Input graph.
        :param X: pd.DataFrame
            Input node features. Each column represents one input feature, and
            each row is a node. Values in the dataframe are numerical, after
            preprocessing.
        :param y: pd.DataFrame
            Input node targets. Each column represents one target, and each
            row is a node (the order of nodes should be the same as in X).
        :param train_mask: list[int]
            Node indexes (rows) that belong to the train set.
        :param val_mask: list[int]
            Node indexes (rows) that belong to the validation set.
        :param test_mask: list[int]
            Node indexes (rows) that belong to the test set.
        :param original_X: pd.DataFrame, optional
            Input node features before preprocessing. Each column represents
            one input feature, and each row is a node. Values in the dataframe
            can be of any type, including categorical (e.g. string, bool) or
            missing values (None). This is useful if you want to preprocess X
            with the GBDT model.
        :param cat_features: list[int]
            Feature indexes (columns) which are categorical features.
        :param num_epochs: int
            Number of epochs to run.
        :param patience: int
            Number of epochs to wait before early stopping.
        :param logging_epochs: int
            Log every n-th epoch.
        :param metric_name: str
            Metric to use for early stopping.
        :return: metrics evaluated during training
        """

        # initialize for early stopping and metrics
        if metric_name in ["r2", "accuracy"]:
            best_metric = [float("-inf")] * 3  # for train/val/test
        else:
            best_metric = [float("inf")] * 3  # for train/val/test
        best_val_epoch = 0
        epochs_since_last_best_metric = 0
        metrics = ddict(list)
        if cat_features is None:
            cat_features = []

        if self.task == "regression":
            self.out_dim = y.shape[1]
        elif self.task == "classification":
            self.out_dim = len(set(y.iloc[test_mask, 0]))
        self.in_dim = (
            self.out_dim + X.shape[1] if self.append_gbdt_pred else self.out_dim
        )

        if original_X is None:
            original_X = X.copy()
            cat_features = []

        gbdt_X_train = original_X.iloc[train_mask]
        gbdt_y_train = y.iloc[train_mask]
        gbdt_alpha = self.gbdt_alpha
        self.gbdt_model = None

        node_features = self.init_node_features(X)
        optimizer = self.init_optimizer(
            node_features, optimize_node_features=True, learning_rate=self.lr
        )

        y = (
            torch.from_numpy(y.to_numpy(copy=True))
            .float()
            .squeeze()
            .to(self.device)
        )
        graph = graph.to(self.device)

        pbar = tqdm(range(num_epochs))
        for epoch in pbar:
            start2epoch = time.time()

            # gbdt part
            self.train_gbdt(
                gbdt_X_train,
                gbdt_y_train,
                cat_features,
                epoch,
                self.trees_per_epoch,
                gbdt_alpha,
            )

            self.update_node_features(node_features, X, original_X)
            node_features_before = node_features.clone()
            model_in = (graph, node_features)
            loss = self.train_and_evaluate(
                model_in,
                y,
                train_mask,
                val_mask,
                test_mask,
                optimizer,
                metrics,
                self.backprop_per_epoch,
            )
            gbdt_y_train = self.update_gbdt_targets(
                node_features, node_features_before, train_mask
            )

            self.log_epoch(
                pbar,
                metrics,
                epoch,
                loss,
                time.time() - start2epoch,
                logging_epochs,
                metric_name=metric_name,
            )

            # check early stopping
            (
                best_metric,
                best_val_epoch,
                epochs_since_last_best_metric,
            ) = self.update_early_stopping(
                metrics,
                epoch,
                best_metric,
                best_val_epoch,
                epochs_since_last_best_metric,
                metric_name,
                lower_better=(metric_name not in ["r2", "accuracy"]),
            )
            if patience and epochs_since_last_best_metric > patience:
                break

            if np.isclose(gbdt_y_train.sum(), 0.0):
                print("Node embeddings do not change anymore. Stopping...")
                break

        print(
            "Best {} at iteration {}: {:.3f}/{:.3f}/{:.3f}".format(
                metric_name, best_val_epoch, *best_metric
            )
        )
        return metrics
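    # Caveat for predict below: it passes the preprocessed X as original_X
    # too, so the trained GBDT must be able to score X directly (same columns
    # and categorical encoding as the features it was fitted on).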
    def predict(self, graph, X, test_mask):
        graph = graph.to(self.device)
        node_features = torch.empty(X.shape[0], self.in_dim).to(self.device)
        self.update_node_features(node_features, X, X)
        logits = self.model(graph, node_features).squeeze()
        if self.task == "regression":
            return logits[test_mask]
        else:
            return logits[test_mask].max(1)[1]

    def plot_interactive(
        self,
        metrics,
        legend,
        title,
        logx=False,
        logy=False,
        metric_name="loss",
        start_from=0,
    ):
        import plotly.graph_objects as go

        metric_results = metrics[metric_name]
        xs = [list(range(len(metric_results)))] * len(metric_results[0])
        ys = list(zip(*metric_results))

        fig = go.Figure()
        for i in range(len(ys)):
            fig.add_trace(
                go.Scatter(
                    x=xs[i][start_from:],
                    y=ys[i][start_from:],
                    mode="lines+markers",
                    name=legend[i],
                )
            )

        fig.update_layout(
            title=title,
            title_x=0.5,
            xaxis_title="Epoch",
            yaxis_title=metric_name,
            font=dict(
                size=40,
            ),
            height=600,
        )

        if logx:
            fig.update_layout(xaxis_type="log")
        if logy:
            fig.update_layout(yaxis_type="log")

        fig.show()
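
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the library API). The toy GAT and the
# random data below are illustrative assumptions only; any DGL GNN whose
# forward pass takes (graph, node_features) should work in its place.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import dgl
    import torch.nn as nn
    from dgl.nn import GATConv

    class GAT(nn.Module):
        """A toy two-layer GAT; heads are averaged to keep outputs 2D."""

        def __init__(self, in_dim, hidden_dim, out_dim, num_heads=3):
            super().__init__()
            self.conv1 = GATConv(in_dim, hidden_dim, num_heads)
            self.conv2 = GATConv(hidden_dim, out_dim, num_heads)

        def forward(self, graph, features):
            h = self.conv1(graph, features).mean(dim=1)
            h = torch.relu(h)
            return self.conv2(graph, h).mean(dim=1)

    # Synthetic regression problem on a random graph.
    n_nodes, n_feats = 100, 10
    graph = dgl.add_self_loop(dgl.rand_graph(n_nodes, 500))
    X = pd.DataFrame(
        np.random.randn(n_nodes, n_feats),
        columns=[f"f{i}" for i in range(n_feats)],
    )
    y = pd.DataFrame(np.random.randn(n_nodes, 1), columns=["target"])

    perm = np.random.permutation(n_nodes)
    train_mask = perm[:60].tolist()
    val_mask = perm[60:80].tolist()
    test_mask = perm[80:].tolist()

    # The GNN's input dim must match BGNN's node features: X plus one
    # GBDT-prediction column per target when append_gbdt_pred=True.
    gnn_model = GAT(n_feats + 1, 20, out_dim=1)
    bgnn = BGNNPredictor(gnn_model, task="regression")
    metrics = bgnn.fit(
        graph, X, y, train_mask, val_mask, test_mask, num_epochs=10
    )
    predictions = bgnn.predict(graph, X, test_mask)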