# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from pathlib import Path from typing import Tuple from megatron.training.global_vars import get_wandb_writer from megatron.training.utils import print_rank_last def _get_wandb_artifact_tracker_filename(save_dir: str) -> Path: """Wandb artifact tracker file rescords the latest artifact wandb entity and project""" return Path(save_dir) / "latest_wandb_artifact_path.txt" def _get_artifact_name_and_version(save_dir: Path, checkpoint_path: Path) -> Tuple[str, str]: return save_dir.stem, checkpoint_path.stem def on_save_checkpoint_success(checkpoint_path: str, tracker_filename: str, save_dir: str, iteration: int) -> None: """Function to be called after checkpointing succeeds and checkpoint is persisted for logging it as an artifact in W&B Args: checkpoint_path (str): path of the saved checkpoint tracker_filename (str): path of the tracker filename for the checkpoint iteration save_dir (str): path of the root save folder for all checkpoints iteration (int): iteration of the checkpoint """ wandb_writer = get_wandb_writer() if wandb_writer: metadata = {"iteration": iteration} artifact_name, artifact_version = _get_artifact_name_and_version(Path(save_dir), Path(checkpoint_path)) artifact = wandb_writer.Artifact(artifact_name, type="model", metadata=metadata) artifact.add_reference(f"file://{checkpoint_path}", checksum=False) artifact.add_file(tracker_filename) wandb_writer.run.log_artifact(artifact, aliases=[artifact_version]) wandb_tracker_filename = _get_wandb_artifact_tracker_filename(save_dir) wandb_tracker_filename.write_text(f"{wandb_writer.run.entity}/{wandb_writer.run.project}") def on_load_checkpoint_success(checkpoint_path: str, load_dir: str) -> None: """Function to be called after succesful loading of a checkpoint, for aggregation and logging it to W&B Args: checkpoint_path (str): path of the loaded checkpoint load_dir (str): path of the root save folder for all checkpoints iteration (int): iteration of the checkpoint """ wandb_writer = get_wandb_writer() if wandb_writer: try: artifact_name, artifact_version = _get_artifact_name_and_version(Path(load_dir), Path(checkpoint_path)) wandb_tracker_filename = _get_wandb_artifact_tracker_filename(load_dir) artifact_path = "" if wandb_tracker_filename.is_file(): artifact_path = wandb_tracker_filename.read_text().strip() artifact_path = f"{artifact_path}/" wandb_writer.run.use_artifact(f"{artifact_path}{artifact_name}:{artifact_version}") except Exception: print_rank_last(f" failed to find checkpoint {checkpoint_path} in wandb")