"""Plotting utilities for monitoring training runs.

The module reads a results log (one JSON object per line), runs the model over
the train/validation/test data loaders to collect reference-vs-predicted
values, and renders one summary figure per head.
"""

import json
import logging
import os
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.distributed
from torchmetrics import Metric

plt.rcParams.update({"font.size": 8})

mpl_logger = logging.getLogger("matplotlib")
mpl_logger.setLevel(logging.WARNING)  # Only show WARNING and above

colors = [
    "#1f77b4",  # muted blue
    "#d62728",  # brick red
    "#7f7f7f",  # middle gray
    "#2ca02c",  # cooked asparagus green
    "#ff7f0e",  # safety orange
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]

# Maps a table type to a pair of lists:
#   1. (results-log column, axis label) pairs plotted against the epoch;
#   2. (quantity key, axis label) pairs plotted as reference-vs-predicted scatter.
# NOTE(review): the dipole tables reference "rel_rmse_f" / "rel_mae_f" under a
# "Relative MU ..." label -- confirm these are the column names actually logged.
error_type = {
    "TotalRMSE": (
        [("rmse_e", "RMSE E [meV]"), ("rmse_f", "RMSE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "PerAtomRMSE": (
        [("rmse_e_per_atom", "RMSE E/atom [meV]"), ("rmse_f", "RMSE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "PerAtomRMSEstressvirials": (
        [
            ("rmse_e_per_atom", "RMSE E/atom [meV]"),
            ("rmse_f", "RMSE F [meV / A]"),
            ("rmse_stress", "RMSE Stress [meV / A^3]"),
        ],
        [
            ("energy", "Energy per atom [eV]"),
            ("force", "Force [eV / A]"),
            ("stress", "Stress [eV / A^3]"),
        ],
    ),
    "PerAtomMAEstressvirials": (
        [
            ("mae_e_per_atom", "MAE E/atom [meV]"),
            ("mae_f", "MAE F [meV / A]"),
            ("mae_stress", "MAE Stress [meV / A^3]"),
        ],
        [
            ("energy", "Energy per atom [eV]"),
            ("force", "Force [eV / A]"),
            ("stress", "Stress [eV / A^3]"),
        ],
    ),
    "TotalMAE": (
        [("mae_e", "MAE E [meV]"), ("mae_f", "MAE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "PerAtomMAE": (
        [("mae_e_per_atom", "MAE E/atom [meV]"), ("mae_f", "MAE F [meV / A]")],
        [("energy", "Energy per atom [eV]"), ("force", "Force [eV / A]")],
    ),
    "DipoleRMSE": (
        [
            ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"),
            ("rel_rmse_f", "Relative MU RMSE [%]"),
        ],
        [("dipole", "Dipole per atom [Debye]")],
    ),
    "DipoleMAE": (
        [("mae_mu", "MAE MU [mDebye]"), ("rel_mae_f", "Relative MU MAE [%]")],
        [("dipole", "Dipole per atom [Debye]")],
    ),
    "EnergyDipoleRMSE": (
        [
            ("rmse_e_per_atom", "RMSE E/atom [meV]"),
            ("rmse_f", "RMSE F [meV / A]"),
            ("rmse_mu_per_atom", "RMSE MU/atom [mDebye]"),
        ],
        [
            ("energy", "Energy per atom [eV]"),
            ("force", "Force [eV / A]"),
            ("dipole", "Dipole per atom [Debye]"),
        ],
    ),
}

# Quantity key (as used in ``error_type``) -> (key in the inference results,
# reference field, predicted field) used for the scatter plots.
_SCATTER_FIELDS = {
    "energy": ("energy", "reference_per_atom", "predicted_per_atom"),
    "force": ("forces", "reference", "predicted"),
    "stress": ("stress", "reference", "predicted"),
    "virials": ("virials", "reference_per_atom", "predicted_per_atom"),
    "dipole": ("dipole", "reference_per_atom", "predicted_per_atom"),
}


class TrainingPlotter:
    """Create per-head summary figures for a training run.

    Each figure has a top row showing the epoch dependence of the losses and
    of the error metrics selected by ``table_type``, and a bottom row with
    reference-vs-predicted scatter plots obtained by running the model.
    """

    def __init__(
        self,
        results_dir: str,
        heads: List[str],
        table_type: str,
        train_valid_data: Dict,
        test_data: Dict,
        output_args: Dict[str, bool],
        device: str,
        plot_frequency: int,
        distributed: bool = False,
        swa_start: Optional[int] = None,
    ):
        """Store the plotting configuration.

        Args:
            results_dir: Path of the results log file (one JSON object per
                line); it is also the base name of the saved figures.
            heads: Names of the heads to create a figure for.
            table_type: Key into ``error_type`` selecting what is plotted.
            train_valid_data: Mapping from dataset name to data loader for the
                training/validation sets.
            test_data: Mapping from dataset name to data loader for test sets.
            output_args: Flags read with ``.get`` for "forces", "virials" and
                "stress" to decide which model outputs to compute.
            device: Device the batches and the metric are moved to.
            plot_frequency: Stored for callers; not used by this class itself.
            distributed: Whether inference runs in a distributed setting.
            swa_start: Epoch at which stage two starts, or ``None``.
        """
        self.results_dir = results_dir
        self.heads = heads
        self.table_type = table_type
        self.train_valid_data = train_valid_data
        self.test_data = test_data
        self.output_args = output_args
        self.device = device
        self.plot_frequency = plot_frequency
        self.distributed = distributed
        self.swa_start = swa_start

    def plot(self, model_epoch: int, model: torch.nn.Module, rank: int) -> None:
        """Run inference and save one figure per head.

        Args:
            model_epoch: Epoch the model was loaded from. It is drawn as a
                vertical line and compared with ``swa_start``, so it must be
                numeric.
            model: Model to evaluate.
            rank: Process rank; every rank runs inference, only rank 0 plots.
        """
        # All ranks process data through model_inference
        train_valid_dict = model_inference(
            self.train_valid_data,
            model,
            self.output_args,
            self.device,
            self.distributed,
        )
        test_dict = model_inference(
            self.test_data, model, self.output_args, self.device, self.distributed
        )

        # Only rank 0 creates and saves plots
        if rank != 0:
            return

        data = pd.DataFrame(parse_training_results(self.results_dir))
        labels, quantities = error_type[self.table_type]

        # The stage only depends on the epoch, not on the head.
        if self.swa_start is not None and self.swa_start < model_epoch:
            stage = "stage_two"
        else:
            stage = "stage_one"

        # Strip the log file's extension (e.g. ".txt") whatever its length.
        base_path = os.path.splitext(self.results_dir)[0]

        for head in self.heads:
            fig = plt.figure(layout="constrained", figsize=(10, 6))
            fig.suptitle(
                f"Model loaded from epoch {model_epoch} ({head} head)", fontsize=16
            )

            subfigs = fig.subfigures(2, 1, height_ratios=[1, 1], hspace=0.05)
            axsTop = subfigs[0].subplots(1, 2, sharey=False)
            axsBottom = subfigs[1].subplots(1, len(quantities), sharey=False)

            plot_epoch_dependence(axsTop, data, head, model_epoch, labels)

            # Use the pre-computed results for plotting
            plot_inference_from_results(
                axsBottom, train_valid_dict, test_dict, head, quantities
            )

            if self.swa_start is not None:
                # Add vertical lines to both axes
                for ax in axsTop:
                    ax.axvline(
                        self.swa_start,
                        color="black",
                        linestyle="dashed",
                        linewidth=1,
                        alpha=0.6,
                        label="Stage Two Starts",
                    )

            axsTop[0].legend(loc="best")

            # Save the figure using the appropriate stage in the filename
            filename = f"{base_path}_{head}_{stage}.png"
            fig.savefig(filename, dpi=300, bbox_inches="tight")
            plt.close(fig)


def parse_training_results(path: str) -> List[dict]:
    """Read a results log containing one JSON object per line.

    Blank lines are ignored; lines that are not valid JSON are skipped with a
    warning so that a partially corrupted log can still be plotted.

    Args:
        path: Path of the results log file.

    Returns:
        The decoded entries, in file order.
    """
    results = []
    with open(path, mode="r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            try:
                results.append(json.loads(stripped))
            except json.JSONDecodeError:
                # Handle non-JSON lines gracefully
                logging.warning("Skipping invalid line: %s", stripped)
    return results


def plot_epoch_dependence(
    axes: np.ndarray,
    data: pd.DataFrame,
    head: str,
    model_epoch: int,
    labels: List[Tuple[str, str]],
) -> None:
    """Plot losses and selected error metrics as a function of the epoch.

    Args:
        axes: Two axes; ``axes[0]`` receives the training/validation loss and
            ``axes[1]`` the metrics listed in ``labels``.
        data: Parsed results log with at least the columns "mode", "epoch",
            "head", "loss" and every key named in ``labels``.
        head: Head whose validation rows are plotted.
        model_epoch: Epoch marked with a vertical "Loaded Model" line.
        labels: ``(column, axis label)`` pairs to plot on ``axes[1]``.
    """
    valid_data = (
        data[data["mode"] == "eval"]
        .groupby(["mode", "epoch", "head"])
        .agg(["mean", "std"])
        .reset_index()
    )
    valid_data = valid_data[valid_data["head"] == head]
    train_data = (
        data[data["mode"] == "opt"]
        .groupby(["mode", "epoch"])
        .agg(["mean", "std"])
        .reset_index()
    )

    # ---- Plot loss ----
    ax = axes[0]
    ax.plot(
        train_data["epoch"], train_data["loss"]["mean"], color=colors[1], linewidth=1
    )
    ax.set_ylabel("Training Loss", color=colors[1])
    ax.set_yscale("log")

    ax2 = ax.twinx()
    ax2.plot(
        valid_data["epoch"], valid_data["loss"]["mean"], color=colors[0], linewidth=1
    )
    ax2.set_ylabel("Validation Loss", color=colors[0])
    ax2.set_yscale("log")

    ax.axvline(
        model_epoch,
        color="black",
        linestyle="solid",
        linewidth=1,
        alpha=0.8,
        label="Loaded Model",
    )
    ax.set_xlabel("Epoch")
    ax.grid(True, linestyle="--", alpha=0.5)

    # ---- Plot selected keys ----
    ax = axes[1]
    for i, (key, axis_label) in enumerate(labels):
        # Wrap around so a long label list cannot run off the palette.
        color = colors[(i + 3) % len(colors)]

        if i == 0:
            main_ax = ax
        else:
            # Every further metric gets its own y axis, shifted outwards so
            # the scales do not overlap.
            main_ax = ax.twinx()
            main_ax.spines.right.set_position(("outward", 60 * (i - 1)))

        main_ax.plot(
            valid_data["epoch"],
            valid_data[key]["mean"] * 1e3,
            color=color,
            label=axis_label,
            linewidth=1,
        )
        main_ax.set_yscale("log")
        main_ax.set_ylabel(axis_label, color=color)
        main_ax.tick_params(axis="y", colors=color)

    ax.axvline(
        model_epoch,
        color="black",
        linestyle="solid",
        linewidth=1,
        alpha=0.8,
        label="Loaded Model",
    )
    ax.set_xlabel("Epoch")
    ax.grid(True, linestyle="--", alpha=0.5)


# INFERENCE=========


def _scatter_quantity(ax, result: dict, key: str, marker: str, color: str, label: str):
    """Scatter reference vs. predicted values of one quantity, if available.

    Returns the scatter artist, or ``None`` when ``key`` is unknown or the
    quantity is missing from ``result``.
    """
    fields = _SCATTER_FIELDS.get(key)
    if fields is None:
        return None
    result_key, ref_field, pred_field = fields
    if result_key not in result:
        return None
    values = result[result_key]
    return ax.scatter(
        values[ref_field],
        values[pred_field],
        marker=marker,
        color=color,
        label=label,
    )


def plot_inference_from_results(
    axes: np.ndarray,
    train_valid_dict: dict,
    test_dict: dict,
    head: str,
    quantities: List[Tuple[str, str]],
) -> None:
    """Draw reference-vs-predicted scatter plots from pre-computed results.

    Args:
        axes: One axes per entry of ``quantities``. A single ``Axes`` object
            (as returned by ``subplots(1, 1)``) is accepted as well.
        train_valid_dict: Results of ``model_inference`` for train/valid sets,
            keyed by dataset name. Only names containing ``head`` are drawn;
            names containing "train" are drawn as training data.
        test_dict: Results of ``model_inference`` for the test sets; all
            entries are drawn under a single "Test" legend entry.
        head: Head name used to select the train/valid datasets.
        quantities: ``(quantity key, axis label)`` pairs, one per axes.
    """
    # ``subplots`` squeezes a 1x1 grid to a bare Axes, which is not iterable.
    axes = np.atleast_1d(axes).ravel()
    fixed_color_test = colors[2]  # Color for test dataset

    for ax, (key, label) in zip(axes, quantities):
        # Store legend handles to avoid duplicates
        legend_labels = {}

        # Plot train/valid data (each entry keeps its own name)
        for name, result in train_valid_dict.items():
            if head not in name:
                continue
            if "train" in name:
                fixed_color_train_valid = colors[1]
                marker = "x"
            else:
                fixed_color_train_valid = colors[0]
                marker = "+"

            scatter = _scatter_quantity(
                ax, result, key, marker, fixed_color_train_valid, name
            )
            # Add each train/valid dataset's name to the legend if it was drawn
            if scatter is not None:
                legend_labels[name] = scatter

        # Plot test data (single legend entry)
        for result in test_dict.values():
            scatter = _scatter_quantity(
                ax, result, key, "o", fixed_color_test, "Test"
            )
            if scatter is not None:
                legend_labels["Test"] = scatter

        # Add diagonal line for guide
        min_val = min(ax.get_xlim()[0], ax.get_ylim()[0])
        max_val = max(ax.get_xlim()[1], ax.get_ylim()[1])
        ax.plot(
            [min_val, max_val],
            [min_val, max_val],
            linestyle="--",
            color="black",
            alpha=0.7,
        )

        # Set legend with unique entries (Test + individual train/valid names)
        if legend_labels:
            ax.legend(
                handles=list(legend_labels.values()),
                labels=list(legend_labels.keys()),
                loc="best",
            )
        ax.set_xlabel(f"Reference {label}")
        ax.set_ylabel(f"MACE {label}")
        ax.grid(True, linestyle="--", alpha=0.5)


def model_inference(
    all_data_loaders: dict,
    model: torch.nn.Module,
    output_args: Dict[str, bool],
    device: str,
    distributed: bool = False,
) -> Dict[str, dict]:
    """Run the model over every data loader and collect scatter-plot data.

    Parameter gradients are disabled while the model is evaluated and each
    parameter's original ``requires_grad`` flag is restored afterwards, even
    if inference raises, so deliberately frozen parameters stay frozen.

    Args:
        all_data_loaders: Mapping from dataset name to data loader.
        model: Model to evaluate; called with ``training=False``.
        output_args: Flags read with ``.get`` for "forces", "virials" and
            "stress" (missing keys count as ``False``).
        device: Device the batches and the metric are moved to.
        distributed: If true, synchronise the ranks with a barrier before the
            results of each dataset are computed.

    Returns:
        Mapping from dataset name to the output of ``InferenceMetric.compute``.
    """
    params = list(model.parameters())
    original_flags = [param.requires_grad for param in params]
    for param in params:
        param.requires_grad = False

    results_dict = {}
    try:
        for name, data_loader in all_data_loaders.items():
            logging.debug("Running inference on %s dataset", name)
            scatter_metric = InferenceMetric().to(device)

            for batch in data_loader:
                batch = batch.to(device)
                batch_dict = batch.to_dict()
                output = model(
                    batch_dict,
                    training=False,
                    compute_force=output_args.get("forces", False),
                    compute_virials=output_args.get("virials", False),
                    compute_stress=output_args.get("stress", False),
                )
                # Only accumulate; the per-batch result that calling the
                # metric would compute is never used.
                scatter_metric.update(batch, output)

            # NOTE(review): barrier placed after the batch loop, once per
            # dataset -- confirm this matches the intended synchronisation.
            if distributed:
                torch.distributed.barrier()

            results_dict[name] = scatter_metric.compute()
            scatter_metric.reset()
    finally:
        for param, flag in zip(params, original_flags):
            param.requires_grad = flag

    return results_dict


def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Return ``tensor`` as a NumPy array on the CPU, detached from autograd."""
    return tensor.cpu().detach().numpy()


class InferenceMetric(Metric):
    """Metric class for collecting reference and predicted values for scatterplot visualization."""

    def __init__(self):
        super().__init__()
        # Raw values
        for quantity in ("energies", "forces", "stress", "virials", "dipole"):
            self.add_state(f"ref_{quantity}", default=[], dist_reduce_fx="cat")
            self.add_state(f"pred_{quantity}", default=[], dist_reduce_fx="cat")

        # Per-atom normalized values
        for quantity in ("energies", "virials", "dipole"):
            self.add_state(
                f"ref_{quantity}_per_atom", default=[], dist_reduce_fx="cat"
            )
            self.add_state(
                f"pred_{quantity}_per_atom", default=[], dist_reduce_fx="cat"
            )

        # Store atom counts for each configuration
        self.add_state("atom_counts", default=[], dist_reduce_fx="cat")

        # Counters (number of batches that provided each quantity)
        for quantity in ("energy", "forces", "stress", "virials", "dipole"):
            self.add_state(
                f"n_{quantity}", default=torch.tensor(0.0), dist_reduce_fx="sum"
            )

    def update(self, batch, output):  # pylint: disable=arguments-differ
        """Update metric states with new batch data.

        Args:
            batch: Batch providing ``ptr`` and the reference ``energy``,
                ``forces``, ``stress``, ``virials`` and ``dipole`` attributes
                (each may be ``None``).
            output: Model output mapping; a quantity is recorded only when
                both the prediction and the reference are not ``None``.
        """
        # Calculate number of atoms per configuration
        atoms_per_config = batch.ptr[1:] - batch.ptr[:-1]
        self.atom_counts.append(atoms_per_config)

        # Predictions are detached before being stored so that accumulating
        # them does not keep each batch's autograd graph alive.

        # Energy
        if output.get("energy") is not None and batch.energy is not None:
            pred_energy = output["energy"].detach()
            self.n_energy += 1.0
            self.ref_energies.append(batch.energy)
            self.pred_energies.append(pred_energy)
            # Per-atom normalization
            self.ref_energies_per_atom.append(batch.energy / atoms_per_config)
            self.pred_energies_per_atom.append(pred_energy / atoms_per_config)

        # Forces
        if output.get("forces") is not None and batch.forces is not None:
            self.n_forces += 1.0
            self.ref_forces.append(batch.forces)
            self.pred_forces.append(output["forces"].detach())

        # Stress
        if output.get("stress") is not None and batch.stress is not None:
            self.n_stress += 1.0
            self.ref_stress.append(batch.stress)
            self.pred_stress.append(output["stress"].detach())

        # Virials
        if output.get("virials") is not None and batch.virials is not None:
            pred_virials = output["virials"].detach()
            self.n_virials += 1.0
            self.ref_virials.append(batch.virials)
            self.pred_virials.append(pred_virials)
            # Per-atom normalization
            atoms_per_config_3d = atoms_per_config.view(-1, 1, 1)
            self.ref_virials_per_atom.append(batch.virials / atoms_per_config_3d)
            self.pred_virials_per_atom.append(pred_virials / atoms_per_config_3d)

        # Dipole
        if output.get("dipole") is not None and batch.dipole is not None:
            pred_dipole = output["dipole"].detach()
            self.n_dipole += 1.0
            self.ref_dipole.append(batch.dipole)
            self.pred_dipole.append(pred_dipole)
            # Per-atom normalization
            atoms_per_config_2d = atoms_per_config.view(-1, 1)
            self.ref_dipole_per_atom.append(batch.dipole / atoms_per_config_2d)
            self.pred_dipole_per_atom.append(pred_dipole / atoms_per_config_2d)

    def _process_data(self, ref_list, pred_list):
        """Flatten one reference/prediction state pair into 1-D NumPy arrays.

        Returns ``(None, None)`` when the state is an empty list or of an
        unexpected type.
        """
        # The state is normally a list of per-batch tensors ...
        if isinstance(ref_list, (list, tuple)):
            if len(ref_list) == 0:
                return None, None
            ref = torch.cat(ref_list).reshape(-1)
            pred = torch.cat(pred_list).reshape(-1)
        # ... but it may already be a single tensor (presumably after the
        # states were concatenated in distributed mode).
        elif isinstance(ref_list, torch.Tensor):
            ref = ref_list.reshape(-1)
            pred = pred_list.reshape(-1)
        # Handle other possible types
        else:
            return None, None
        return to_numpy(ref), to_numpy(pred)

    def compute(self):
        """Compute final results for scatterplot.

        Returns:
            A dict with an entry for each of "energy", "forces", "stress",
            "virials" and "dipole" that was recorded at least once. Each entry
            holds flattened "reference"/"predicted" arrays; energy, virials
            and dipole also hold "reference_per_atom"/"predicted_per_atom".
        """
        results = {}

        # Process energies
        if self.n_energy:
            ref_e, pred_e = self._process_data(self.ref_energies, self.pred_energies)
            ref_e_pa, pred_e_pa = self._process_data(
                self.ref_energies_per_atom, self.pred_energies_per_atom
            )
            results["energy"] = {
                "reference": ref_e,
                "predicted": pred_e,
                "reference_per_atom": ref_e_pa,
                "predicted_per_atom": pred_e_pa,
            }

        # Process forces
        if self.n_forces:
            ref_f, pred_f = self._process_data(self.ref_forces, self.pred_forces)
            results["forces"] = {
                "reference": ref_f,
                "predicted": pred_f,
            }

        # Process stress
        if self.n_stress:
            ref_s, pred_s = self._process_data(self.ref_stress, self.pred_stress)
            results["stress"] = {
                "reference": ref_s,
                "predicted": pred_s,
            }

        # Process virials
        if self.n_virials:
            ref_v, pred_v = self._process_data(self.ref_virials, self.pred_virials)
            ref_v_pa, pred_v_pa = self._process_data(
                self.ref_virials_per_atom, self.pred_virials_per_atom
            )
            results["virials"] = {
                "reference": ref_v,
                "predicted": pred_v,
                "reference_per_atom": ref_v_pa,
                "predicted_per_atom": pred_v_pa,
            }

        # Process dipoles
        if self.n_dipole:
            ref_d, pred_d = self._process_data(self.ref_dipole, self.pred_dipole)
            ref_d_pa, pred_d_pa = self._process_data(
                self.ref_dipole_per_atom, self.pred_dipole_per_atom
            )
            results["dipole"] = {
                "reference": ref_d,
                "predicted": pred_d,
                "reference_per_atom": ref_d_pa,
                "predicted_per_atom": pred_d_pa,
            }

        return results