# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Validate speculative decoding acceptance rate (AR) with ModelOpt."""

import functools
import json
import os
import sys
import warnings

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))

import torch
from modelopt.torch.speculative.plugins.megatron_eagle import MegatronARValidation

from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.model_provider import model_provider
from megatron.post_training.utils import get_mtbench_chat_data
from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron
from megatron.training.utils import print_rank_0, unwrap_model

warnings.filterwarnings('ignore')


def add_ar_validation_args(parser):
    """Add additional arguments for ModelOpt acceptance rate validation."""
    group = parser.add_argument_group(title='ModelOpt ar validation')
    group.add_argument("--osl", type=int, default=64, help="Output sequence length.")
    group.add_argument(
        "--prompts-path",
        type=str,
        default=None,
        help="Path to the prompts JSONL file. If not provided, MTBench will be used.",
    )
    group.add_argument(
        "--ground-truth-path",
        type=str,
        default=None,
        help="Path to the ground truth .pt file.",
    )
    group.add_argument("--steps", type=int, default=1, help="Only used in EAGLE.")
    group.add_argument(
        "--save-ground-truth-path",
        type=str,
        default=None,
        help="Save path for the ground truth .pt file.",
    )
    add_modelopt_args(parser)
    return parser


def check_arguments():
    """Check user arguments."""
    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    if hasattr(args, 'moe_grouped_gemm') and args.moe_grouped_gemm:
        print_rank_0("WARNING: Forcing moe_grouped_gemm to False for PTQ and export.")
        args.moe_grouped_gemm = False


def get_current_memory_info():
    """Return a string summarizing this rank's remaining GPU memory."""
    remaining_mem, total_mem = torch.cuda.mem_get_info()
    info = "rank {:02} memory remaining {:03}% ({}/{} MB) ".format(
        torch.distributed.get_rank(),
        int(remaining_mem * 100 / total_mem),
        remaining_mem // 1048576,
        total_mem // 1048576,
    )
    return info


def report_current_memory_info():
    """Report current memory usage."""
    print(get_current_memory_info(), flush=True)
    torch.distributed.barrier()


if __name__ == "__main__":
    initialize_megatron(
        extra_args_provider=add_ar_validation_args,
        args_defaults={
            'tokenizer_type': 'HuggingFaceTokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )

    check_arguments()

    args = get_args()

    # Load prompts: MTBench chat data by default, otherwise a user-provided JSONL file.
    if not args.prompts_path:
        dataset = get_mtbench_chat_data()
        prompts = [[sample["conversations"][0]] for sample in dataset]
    else:
        with open(args.prompts_path, "r") as f:
            prompts = [json.loads(line) for line in f]
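    # NOTE (assumption, inferred from the two loaders above): each line of the
    # --prompts-path file must parse as a single JSON document holding one prompt.
    # Since the MTBench branch wraps a single conversation turn in a list, a
    # compatible JSONL line would likewise be a list of chat turns, e.g. a
    # hypothetical ShareGPT-style line:
    #   [{"from": "human", "value": "Explain speculative decoding."}]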
    # Optionally load precomputed ground-truth token tensors and move them to
    # the current device; otherwise validate each prompt without a reference.
    if args.ground_truth_path is not None:
        ground_truth = torch.load(args.ground_truth_path)
        ground_truth = [gt.to(torch.cuda.current_device()) for gt in ground_truth]
    else:
        ground_truth = [None for _ in range(len(prompts))]

    tokenizer = get_tokenizer()._tokenizer
    model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)

    report_current_memory_info()

    if args.load is not None:
        load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
        print_rank_0("Done loading checkpoint")

    unwrapped_model = unwrap_model(model)[0]
    unwrapped_model.eval()

    validator = MegatronARValidation(unwrapped_model, tokenizer)

    # Validate each prompt; output[0] is the ground-truth tokens and output[1]
    # the per-prompt acceptance rate.
    gt = []
    ar = []
    for prompt, truth in zip(prompts, ground_truth):
        output = validator.validate(args.osl, prompt, ground_truth=truth, steps=args.steps)
        gt.append(output[0])
        ar.append(output[1])

    print_rank_0("Acceptance Rate: " + str(ar))
    print_rank_0("Average: " + str(sum(ar) / len(ar)))

    if args.save_ground_truth_path is not None:
        torch.save(gt, args.save_ground_truth_path)
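# Example launch (illustrative only; the script name, checkpoint path, tokenizer,
# and parallelism flags below are placeholders to adapt to your setup):
#
#   torchrun --nproc_per_node 8 ar_validate.py \
#       --load /path/to/eagle_checkpoint \
#       --tokenizer-model <hf_tokenizer_name_or_path> \
#       --tensor-model-parallel-size 8 \
#       --osl 64 --steps 1 \
#       --prompts-path prompts.jsonl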