# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. """Sample Generate GPT.""" import functools import os import sys from pathlib import Path sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))) import modelopt.torch.quantization as mtq import torch from datasets import load_dataset from modelopt.torch.utils.distributed import set_data_parallel_group, set_tensor_parallel_group from tqdm import tqdm # [ModelOpt]: changing the default model provider to the ModelOpt version from megatron.core import mpu from megatron.inference.arguments import add_modelopt_args from megatron.inference.checkpointing import load_modelopt_checkpoint from megatron.inference.gpt.model_provider import model_provider from megatron.inference.text_generation import generate_and_post_process from megatron.training import get_args, get_model, initialize_megatron from megatron.training.checkpointing import save_checkpoint from megatron.training.utils import print_rank_0, unwrap_model QUANT_CFG_CHOICES = { "int8": mtq.INT8_DEFAULT_CFG, "int8_sq": mtq.INT8_SMOOTHQUANT_CFG, "fp8": mtq.FP8_DEFAULT_CFG, "int4_awq": mtq.INT4_AWQ_CFG, "w4a8_awq": mtq.W4A8_AWQ_BETA_CFG, "int4": mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, } def add_trtllm_ckpt_export_args(parser): """Add additional arguments for TensorRT-LLM.""" group = parser.add_argument_group(title="trtllm") group.add_argument( "--export-dir", type=str, help="The output TensorRT-LLM checkpoint.", ) group.add_argument( "--decoder", type=str, choices=["gptnext", 'llama'], help="The decoder type of the model.", ) group.add_argument( "--inference-tensor-parallel", type=int, help="Tensor parallel for the inference time, can be different from the training config.", default=1, ) def add_text_generate_ptq_args(parser): """Add additional arguments for ModelOpt text generation PTQ.""" group = parser.add_argument_group(title='ModelOpt text generation ptq') group.add_argument( "--calib-dataset", type=str, default="cnn_dailymail", help="Calibration datasets from HuggingFace datasets.", ) group.add_argument( "--calib-batch-size", type=int, default=4, help="Batch size to use for ptq calibration." ) group.add_argument( "--calib-size", type=int, default=512, help="Samples to use for ptq calibration." ) parser.add_argument( "--prompts", type=str, default=( "Born in north-east France, Soyer trained as a|Born in California, Soyer trained as a" ), help="Input texts. Please use | to separate different batches.", ) add_modelopt_args(parser) add_trtllm_ckpt_export_args(parser) return parser def get_calib_dataloader( data="cnn_dailymail", batch_size=4, calib_size=512, max_sequence_length=512 ): if data == "pileval": dataset = load_dataset( "json", data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", split="train" ) text_column = "text" elif data == "wikitext": dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") text_column = "text" elif data == "cnn_dailymail": dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") text_column = "article" calib_size = max(min(len(dataset), calib_size), batch_size) for i in range(calib_size // batch_size): batch = dataset[i * batch_size : (i + 1) * batch_size][text_column] for j in range(len(batch)): batch[j] = batch[j][:max_sequence_length] yield batch if __name__ == "__main__": initialize_megatron( extra_args_provider=add_text_generate_ptq_args, args_defaults={ 'tokenizer_type': 'GPT2BPETokenizer', 'no_load_rng': True, 'no_load_optim': True, }, ) args = get_args() if args.num_layers_per_virtual_pipeline_stage is not None: print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.") exit() print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text generation.") args.exit_on_missing_checkpoint = True # Set up model and load checkpoint # [ModelOpt]: make sure that output logits are allgathered. text_generation_model_provider = functools.partial(model_provider, parallel_output=False) model = get_model(text_generation_model_provider, wrap_with_ddp=False) if args.load is not None: load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights) print_rank_0("Done loading checkpoint") # Removing virtual pipeline parallel and other wrapper assert len(model) == 1, "Above condition should have caught this" unwrapped_model = unwrap_model(model) all_prompts = args.prompts.split("|") def custom_prompt_forward_loop_func(model): for prompt in tqdm(all_prompts): if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: ( prompts_plus_generations, prompts_plus_generations_segments, logprobs, _, ) = generate_and_post_process( model, prompts=[prompt], tokens_to_generate=128, return_output_log_probs=True, temperature=1.0, ) print_rank_0(prompts_plus_generations) else: generate_and_post_process(model) def hf_dataset_forword_loop_func(model): dataloader = get_calib_dataloader(args.calib_dataset, args.calib_batch_size, args.calib_size) for prompts in tqdm(dataloader, total=args.calib_size//args.calib_batch_size): if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0: ( prompts_plus_generations, prompts_plus_generations_segments, logprobs, _, ) = generate_and_post_process( model, prompts=prompts, tokens_to_generate=0, return_output_log_probs=True, temperature=1.0, ) else: generate_and_post_process(model) ptq_forward_loop_func = custom_prompt_forward_loop_func if args.calib_dataset is not None: ptq_forward_loop_func = hf_dataset_forword_loop_func # Setting data parallel and tensor parallel group set_data_parallel_group(mpu.get_data_parallel_group()) set_tensor_parallel_group(mpu.get_tensor_model_parallel_group()) if args.export_quant_cfg in QUANT_CFG_CHOICES: mtq_config = QUANT_CFG_CHOICES[args.export_quant_cfg] if "*output_layer*" not in mtq_config["quant_cfg"]: mtq_config["quant_cfg"]["*output_layer*"] = {"enable": False} if "awq" in args.export_quant_cfg: weight_quantizer = mtq_config["quant_cfg"]["*weight_quantizer"] # type: ignore if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] weight_quantizer["block_sizes"][-1] = 128 print_rank_0("Quantizing the model...") mtq.quantize(unwrapped_model[0], mtq_config, ptq_forward_loop_func) custom_prompt_forward_loop_func(model[0]) if args.save is not None and args.export_quant_cfg in QUANT_CFG_CHOICES: save_checkpoint(1, unwrapped_model, None, None, 0) print_rank_0(f"Fake Quantized Model:\n {unwrapped_model[0]}") if args.export_dir: assert args.decoder in ["gptnext", "llama"], f"Decoder type {args.decoder} not supported." Path(args.export_dir).mkdir(parents=True, exist_ok=True) print_rank_0("Exporting TensorRT-LLM checkpoints.") from modelopt.torch.export import export_tensorrt_llm_checkpoint # In TRT LLM, squared relu activation does not support bf16. So we use fp16 by default. export_tensorrt_llm_checkpoint( unwrapped_model[0], args.decoder, torch.bfloat16 if args.bf16 else torch.float16, export_dir=args.export_dir, inference_tensor_parallel=args.inference_tensor_parallel, inference_pipeline_parallel=1, use_nfs_workspace=True, ) print_rank_0(f"TensorRT-LLM checkpoints saved to {args.export_dir}")