import itertools
import time
from collections import defaultdict as ddict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from catboost import Pool, CatBoostClassifier, CatBoostRegressor, sum_models
from sklearn.metrics import r2_score
from tqdm import tqdm


class BGNNPredictor:
    '''
    Description
    -----------
    Boosted GNN (BGNN) predictor for semi-supervised node classification or regression problems.
    Publication: https://arxiv.org/abs/2101.08543

    Parameters
    ----------
    gnn_model : nn.Module
        DGL implementation of a GNN model.
    task : str, optional
        Regression or classification task.
    loss_fn : callable, optional
        Function that takes torch tensors, pred and true, and returns a scalar.
    trees_per_epoch : int, optional
        Number of GBDT trees to build each epoch.
    backprop_per_epoch : int, optional
        Number of backpropagation steps to make each epoch.
    lr : float, optional
        Learning rate of the gradient descent optimizer.
    append_gbdt_pred : bool, optional
        Append GBDT predictions to the original input node features instead of replacing them.
    train_input_features : bool, optional
        Train original input node features.
    gbdt_depth : int, optional
        Depth of each tree in the GBDT model.
    gbdt_lr : float, optional
        Learning rate of the GBDT model.
    gbdt_alpha : float, optional
        Weight to combine previous and new GBDT trees.
    random_seed : int, optional
        Random seed for the GNN and GBDT models.

    Examples
    --------
    gnn_model = GAT(10, 20, num_heads=5)
    bgnn = BGNNPredictor(gnn_model)
    metrics = bgnn.fit(graph, X, y, train_mask, val_mask, test_mask, cat_features)
    '''

    def __init__(self,
                 gnn_model,
                 task='regression',
                 loss_fn=None,
                 trees_per_epoch=10,
                 backprop_per_epoch=10,
                 lr=0.01,
                 append_gbdt_pred=True,
                 train_input_features=False,
                 gbdt_depth=6,
                 gbdt_lr=0.1,
                 gbdt_alpha=1,
                 random_seed=0):
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = gnn_model.to(self.device)
        self.task = task
        self.loss_fn = loss_fn
        self.trees_per_epoch = trees_per_epoch
        self.backprop_per_epoch = backprop_per_epoch
        self.lr = lr
        self.append_gbdt_pred = append_gbdt_pred
        self.train_input_features = train_input_features
        self.gbdt_depth = gbdt_depth
        self.gbdt_lr = gbdt_lr
        self.gbdt_alpha = gbdt_alpha
        self.random_seed = random_seed
        torch.manual_seed(random_seed)
        np.random.seed(random_seed)

    def init_gbdt_model(self, num_epochs, epoch):
        if self.task == 'regression':
            catboost_model_obj = CatBoostRegressor
            catboost_loss_fn = 'RMSE'
        else:
            if epoch == 0:  # we predict multiclass probabilities at the first epoch
                catboost_model_obj = CatBoostClassifier
                catboost_loss_fn = 'MultiClass'
            else:  # we predict the gradients for each class at epochs > 0
                catboost_model_obj = CatBoostRegressor
                catboost_loss_fn = 'MultiRMSE'

        return catboost_model_obj(iterations=num_epochs,
                                  depth=self.gbdt_depth,
                                  learning_rate=self.gbdt_lr,
                                  loss_function=catboost_loss_fn,
                                  random_seed=self.random_seed,
                                  nan_mode='Min')

    def fit_gbdt(self, pool, trees_per_epoch, epoch):
        gbdt_model = self.init_gbdt_model(trees_per_epoch, epoch)
        gbdt_model.fit(pool, verbose=False)
        return gbdt_model

    def append_gbdt_model(self, new_gbdt_model, weights):
        if self.gbdt_model is None:
            return new_gbdt_model
        return sum_models([self.gbdt_model, new_gbdt_model], weights=weights)

    def train_gbdt(self, gbdt_X_train, gbdt_y_train, cat_features, epoch,
                   gbdt_trees_per_epoch, gbdt_alpha):
        pool = Pool(gbdt_X_train, gbdt_y_train, cat_features=cat_features)
        epoch_gbdt_model = self.fit_gbdt(pool, gbdt_trees_per_epoch, epoch)
        if epoch == 0 and self.task == 'classification':
            self.base_gbdt = epoch_gbdt_model
        else:
            self.gbdt_model = self.append_gbdt_model(epoch_gbdt_model, weights=[1, gbdt_alpha])
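    # The ensemble grows additively: sum_models([old, new], weights=[1, alpha]) returns a
    # CatBoost model whose raw prediction is old(x) + alpha * new(x), so each epoch's trees
    # are blended into the running ensemble with weight gbdt_alpha. For classification, the
    # epoch-0 MultiClass model is kept separately in self.base_gbdt, and the later MultiRMSE
    # models accumulate additive corrections that update_node_features adds on top of its
    # predicted probabilities.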
    def update_node_features(self, node_features, X, original_X):
        # Get predictions from the GBDT model.
        if self.task == 'regression':
            predictions = np.expand_dims(self.gbdt_model.predict(original_X), axis=1)
        else:
            predictions = self.base_gbdt.predict_proba(original_X)
            if self.gbdt_model is not None:
                predictions_after_one = self.gbdt_model.predict(original_X)
                predictions += predictions_after_one

        # Update node features with the predictions.
        if self.append_gbdt_pred:
            if self.train_input_features:
                # Replace the old predictions (last out_dim columns) with the new ones.
                predictions = np.append(node_features.detach().cpu().data[:, :-self.out_dim],
                                        predictions, axis=1)
            else:
                # Append the new predictions to the original input features.
                predictions = np.append(X, predictions, axis=1)

        predictions = torch.from_numpy(predictions).to(self.device)
        node_features.data = predictions.float().data

    def update_gbdt_targets(self, node_features, node_features_before, train_mask):
        # The shift that backpropagation applied to the last out_dim feature columns on the
        # train rows approximates the negative gradient of the GNN loss w.r.t. the GBDT
        # output; it becomes the regression target for the next batch of trees.
        return (node_features - node_features_before).detach().cpu().numpy()[train_mask, -self.out_dim:]

    def init_node_features(self, X):
        node_features = torch.empty(X.shape[0], self.in_dim, requires_grad=True, device=self.device)
        if self.append_gbdt_pred:
            node_features.data[:, :-self.out_dim] = torch.from_numpy(X.to_numpy(copy=True))
        return node_features

    def init_optimizer(self, node_features, optimize_node_features, learning_rate):
        params = [self.model.parameters()]
        if optimize_node_features:
            params.append([node_features])
        optimizer = torch.optim.Adam(itertools.chain(*params), lr=learning_rate)
        return optimizer

    def train_model(self, model_in, target_labels, train_mask, optimizer):
        y = target_labels[train_mask]

        self.model.train()
        logits = self.model(*model_in).squeeze()
        pred = logits[train_mask]

        if self.loss_fn is not None:
            loss = self.loss_fn(pred, y)
        else:
            if self.task == 'regression':
                loss = torch.sqrt(F.mse_loss(pred, y))
            elif self.task == 'classification':
                loss = F.cross_entropy(pred, y.long())
            else:
                raise NotImplementedError('Unknown task. Supported tasks: classification, regression.')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss
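    # A custom loss_fn passed to the constructor replaces the defaults above; it must map
    # (pred, true) tensors to a scalar tensor. An illustrative sketch (not part of the
    # original module):
    #
    #     def huber_loss(pred, true):
    #         return F.smooth_l1_loss(pred, true)
    #
    #     bgnn = BGNNPredictor(gnn_model, task='regression', loss_fn=huber_loss)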
    def evaluate_model(self, logits, target_labels, mask):
        metrics = {}
        y = target_labels[mask]
        with torch.no_grad():
            pred = logits[mask]
            if self.task == 'regression':
                metrics['loss'] = torch.sqrt(F.mse_loss(pred, y).squeeze() + 1e-8)
                metrics['rmsle'] = torch.sqrt(F.mse_loss(torch.log(pred + 1), torch.log(y + 1)).squeeze() + 1e-8)
                metrics['mae'] = F.l1_loss(pred, y)
                metrics['r2'] = torch.Tensor([r2_score(y.cpu().numpy(), pred.cpu().numpy())])
            elif self.task == 'classification':
                metrics['loss'] = F.cross_entropy(pred, y.long())
                metrics['accuracy'] = torch.Tensor([(y == pred.max(1)[1]).sum().item() / y.shape[0]])

        return metrics

    def train_and_evaluate(self, model_in, target_labels, train_mask, val_mask, test_mask,
                           optimizer, metrics, gnn_passes_per_epoch):
        loss = None

        for _ in range(gnn_passes_per_epoch):
            loss = self.train_model(model_in, target_labels, train_mask, optimizer)

        self.model.eval()
        logits = self.model(*model_in).squeeze()
        train_results = self.evaluate_model(logits, target_labels, train_mask)
        val_results = self.evaluate_model(logits, target_labels, val_mask)
        test_results = self.evaluate_model(logits, target_labels, test_mask)
        for metric_name in train_results:
            metrics[metric_name].append((train_results[metric_name].detach().item(),
                                         val_results[metric_name].detach().item(),
                                         test_results[metric_name].detach().item()))
        return loss

    def update_early_stopping(self, metrics, epoch, best_metric, best_val_epoch,
                              epochs_since_last_best_metric, metric_name, lower_better=False):
        train_metric, val_metric, test_metric = metrics[metric_name][-1]
        if (lower_better and val_metric < best_metric[1]) or \
                (not lower_better and val_metric > best_metric[1]):
            best_metric = metrics[metric_name][-1]
            best_val_epoch = epoch
            epochs_since_last_best_metric = 0
        else:
            epochs_since_last_best_metric += 1
        return best_metric, best_val_epoch, epochs_since_last_best_metric

    def log_epoch(self, pbar, metrics, epoch, loss, epoch_time, logging_epochs, metric_name='loss'):
        train_metric, val_metric, test_metric = metrics[metric_name][-1]
        if epoch and epoch % logging_epochs == 0:
            pbar.set_description(
                'Epoch {:05d} | Loss {:.3f} | {} {:.3f}/{:.3f}/{:.3f} | Time {:.4f}'.format(
                    epoch, loss, metric_name, train_metric, val_metric, test_metric, epoch_time))
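    # One fit() epoch alternates the two learners:
    #   1. train_gbdt           - fit trees_per_epoch new trees on the current GBDT targets;
    #   2. update_node_features - write the GBDT predictions into the trainable node features;
    #   3. train_and_evaluate   - run backprop_per_epoch GNN passes, which also move the node
    #                             features through the shared optimizer;
    #   4. update_gbdt_targets  - the resulting feature shift becomes the next GBDT target.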
    def fit(self, graph, X, y, train_mask, val_mask, test_mask,
            original_X=None, cat_features=None,
            num_epochs=100, patience=10, logging_epochs=1, metric_name='loss'):
        '''
        :param graph: dgl.DGLGraph
            Input graph.
        :param X: pd.DataFrame
            Input node features. Each column represents one input feature. Each row is a node.
            Values in the dataframe are numerical, after preprocessing.
        :param y: pd.DataFrame
            Input node targets. Each column represents one target. Each row is a node
            (the order of nodes should be the same as in X).
        :param train_mask: list[int]
            Node indexes (rows) that belong to the train set.
        :param val_mask: list[int]
            Node indexes (rows) that belong to the validation set.
        :param test_mask: list[int]
            Node indexes (rows) that belong to the test set.
        :param original_X: pd.DataFrame, optional
            Input node features before preprocessing. Each column represents one input feature.
            Each row is a node. Values in the dataframe can be of any type, including
            categorical (e.g. string, bool) or missing values (None). This is useful if you
            want to preprocess X with the GBDT model.
        :param cat_features: list[int]
            Feature indexes (columns) which are categorical features.
        :param num_epochs: int
            Number of epochs to run.
        :param patience: int
            Number of epochs to wait until early stopping.
        :param logging_epochs: int
            Log every n-th epoch.
        :param metric_name: str
            Metric to use for early stopping.
        :return: metrics evaluated during training
        '''

        # Initialize early stopping and metrics.
        if metric_name in ['r2', 'accuracy']:
            best_metric = [float('-inf')] * 3  # for train/val/test
        else:
            best_metric = [float('inf')] * 3  # for train/val/test

        best_val_epoch = 0
        epochs_since_last_best_metric = 0
        metrics = ddict(list)
        if cat_features is None:
            cat_features = []

        if self.task == 'regression':
            self.out_dim = y.shape[1]
        elif self.task == 'classification':
            # Derive the number of classes from all labels so no class is missed.
            self.out_dim = len(set(y.iloc[:, 0]))
        self.in_dim = self.out_dim + X.shape[1] if self.append_gbdt_pred else self.out_dim

        if original_X is None:
            original_X = X.copy()
            cat_features = []

        gbdt_X_train = original_X.iloc[train_mask]
        gbdt_y_train = y.iloc[train_mask]
        gbdt_alpha = self.gbdt_alpha
        self.gbdt_model = None

        node_features = self.init_node_features(X)
        optimizer = self.init_optimizer(node_features, optimize_node_features=True, learning_rate=self.lr)

        y = torch.from_numpy(y.to_numpy(copy=True)).float().squeeze().to(self.device)
        graph = graph.to(self.device)

        pbar = tqdm(range(num_epochs))
        for epoch in pbar:
            start2epoch = time.time()

            # GBDT part: fit new trees on the current targets.
            self.train_gbdt(gbdt_X_train, gbdt_y_train, cat_features, epoch,
                            self.trees_per_epoch, gbdt_alpha)

            # GNN part: refresh node features and make backpropagation passes.
            self.update_node_features(node_features, X, original_X)
            node_features_before = node_features.clone()
            model_in = (graph, node_features)
            loss = self.train_and_evaluate(model_in, y, train_mask, val_mask, test_mask,
                                           optimizer, metrics, self.backprop_per_epoch)
            gbdt_y_train = self.update_gbdt_targets(node_features, node_features_before, train_mask)

            self.log_epoch(pbar, metrics, epoch, loss, time.time() - start2epoch, logging_epochs,
                           metric_name=metric_name)

            # Check early stopping.
            best_metric, best_val_epoch, epochs_since_last_best_metric = \
                self.update_early_stopping(metrics, epoch, best_metric, best_val_epoch,
                                           epochs_since_last_best_metric, metric_name,
                                           lower_better=(metric_name not in ['r2', 'accuracy']))
            if patience and epochs_since_last_best_metric > patience:
                break

            if np.isclose(gbdt_y_train.sum(), 0.):
                print('Node embeddings do not change anymore. Stopping...')
                break

        print('Best {} at iteration {}: {:.3f}/{:.3f}/{:.3f}'.format(metric_name, best_val_epoch, *best_metric))
        return metrics

    def predict(self, graph, X, test_mask):
        graph = graph.to(self.device)
        node_features = torch.empty(X.shape[0], self.in_dim).to(self.device)
        self.update_node_features(node_features, X, X)
        logits = self.model(graph, node_features).squeeze()
        if self.task == 'regression':
            return logits[test_mask]
        else:
            return logits[test_mask].max(1)[1]

    def plot_interactive(self, metrics, legend, title, logx=False, logy=False,
                         metric_name='loss', start_from=0):
        import plotly.graph_objects as go

        metric_results = metrics[metric_name]
        xs = [list(range(len(metric_results)))] * len(metric_results[0])
        ys = list(zip(*metric_results))

        fig = go.Figure()
        for i in range(len(ys)):
            fig.add_trace(go.Scatter(x=xs[i][start_from:], y=ys[i][start_from:],
                                     mode='lines+markers', name=legend[i]))

        fig.update_layout(
            title=title,
            title_x=0.5,
            xaxis_title='Epoch',
            yaxis_title=metric_name,
            font=dict(size=40),
            height=600,
        )
        if logx:
            fig.update_layout(xaxis_type='log')
        if logy:
            fig.update_layout(yaxis_type='log')
        fig.show()
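# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It wires
# BGNNPredictor to a hypothetical two-layer GraphConv model on random data; any
# DGL module with a forward(graph, features) signature can be used instead, such
# as the GAT model shown in the class docstring.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import dgl
    from dgl.nn.pytorch import GraphConv

    class SimpleGNN(torch.nn.Module):
        # Hypothetical toy model; kept tiny so the example runs in seconds.
        def __init__(self, in_dim, hidden_dim, out_dim):
            super().__init__()
            self.conv1 = GraphConv(in_dim, hidden_dim, allow_zero_in_degree=True)
            self.conv2 = GraphConv(hidden_dim, out_dim, allow_zero_in_degree=True)

        def forward(self, graph, features):
            h = F.relu(self.conv1(graph, features))
            return self.conv2(graph, h)

    num_nodes, num_features = 100, 8
    graph = dgl.rand_graph(num_nodes, 500)                      # random toy graph
    X = pd.DataFrame(np.random.randn(num_nodes, num_features))  # numeric node features
    y = pd.DataFrame(np.random.randn(num_nodes, 1))             # one regression target

    idx = np.random.permutation(num_nodes)
    train_mask, val_mask, test_mask = idx[:60].tolist(), idx[60:80].tolist(), idx[80:].tolist()

    # With append_gbdt_pred=True the GNN input width is X.shape[1] + y.shape[1].
    gnn_model = SimpleGNN(num_features + 1, 16, 1)
    bgnn = BGNNPredictor(gnn_model, task='regression')
    metrics = bgnn.fit(graph, X, y, train_mask, val_mask, test_mask, num_epochs=10)
    predictions = bgnn.predict(graph, X, test_mask)
    print('Test predictions shape:', predictions.shape)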