generate.py 5.15 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Sample Generate GPT."""
import functools
import os
import sys
import warnings

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))

import torch
from datasets import load_dataset

from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.generate import simple_generate
from megatron.post_training.model_provider import model_provider
from megatron.post_training.utils import report_current_memory_info
from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron
from megatron.training.utils import print_rank_0, unwrap_model

warnings.filterwarnings('ignore')


def add_generate_args(parser):
    """Register the extra command-line options used by this generation script.

    Adds an argument group with the generation / acceptance-rate validation
    settings, then delegates to ``add_modelopt_args`` for the ModelOpt options.

    Args:
        parser: The argparse-style parser handed in through
            ``initialize_megatron(extra_args_provider=...)``.

    Returns:
        The same ``parser`` with the new arguments registered.
    """
    group = parser.add_argument_group(title='ModelOpt ar validation')
    group.add_argument("--osl", type=int, default=128, help="Output sequence length.")
    # In this script a positive --draft-length also switches the default dataset
    # to MT-Bench prompts when no --finetune-hf-dataset is given.
    group.add_argument("--draft-length", type=int, default=0, help="Only used in EAGLE.")
    group.add_argument("--draft-topk", type=int, default=1, help="Only used in EAGLE.")
    group.add_argument("--disable-tqdm", action="store_true", help="Disable tqdm.")
    group.add_argument(
        "--percentage",
        type=float,
        default=1.0,
        help="Fraction of the dataset (0.0-1.0) to run generation on.",
    )

    add_modelopt_args(parser)
    return parser


def check_arguments():
    """Validate and adjust the parsed Megatron arguments for text generation.

    Terminates the process with a non-zero status if an interleaved (virtual)
    pipeline schedule was requested, since this script does not support it.
    If ``moe_grouped_gemm`` is enabled, it is forced off in place on the
    global args object.
    """
    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.")
        # `sys.exit` rather than the `site`-provided `exit` builtin (absent under
        # `python -S` / frozen builds); non-zero so launchers see the failure.
        sys.exit(1)

    # getattr with a default covers arg namespaces that lack this option.
    if getattr(args, 'moe_grouped_gemm', False):
        print_rank_0("WARNING: Forcing moe_grouped_gemm to False for PTQ and export.")
        args.moe_grouped_gemm = False


def mtbench_to_oai_chat(example):
    """Attach an OpenAI-style chat transcript built from an MT-Bench record.

    Each entry of ``example["prompt"]`` becomes one ``{"role": "user", ...}``
    turn, preserving order. The turns are stored under
    ``example["conversations"]``. The same (mutated) ``example`` object is
    returned so the function can be passed to ``Dataset.map``.
    """
    example["conversations"] = [
        {"role": "user", "content": turn} for turn in example["prompt"]
    ]
    return example


def get_conversations(example):
    """Return the chat turns of a dataset record for ``apply_chat_template``.

    The ``"conversations"`` field is preferred. ``"messages"`` is used when
    ``"conversations"`` is missing or ``None``.

    Raises:
        ValueError: if neither field is present with a non-``None`` value.
    """
    for field_name in ("conversations", "messages"):
        turns = example.get(field_name, None)
        if turns is not None:
            return turns
    raise ValueError(
        "The data must either have conversations or messages field, but got {}".format(example)
    )


if __name__ == "__main__":
    initialize_megatron(
        extra_args_provider=add_generate_args,
        args_defaults={
            'tokenizer_type': 'HuggingFaceTokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )

    check_arguments()

    args = get_args()

    # Single-prompt fallback used when no HF dataset is given and drafting is off.
    default_conversations = [
        {
            "role": "user",
            "content": "Write an email to a wine expert, requesting a guest "
            "article contribution for your wine blog.",
        }
    ]

    if args.finetune_hf_dataset is None:
        if args.draft_length > 0:
            # EAGLE runs (positive --draft-length) default to the MT-Bench prompts.
            dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
            dataset = dataset.map(mtbench_to_oai_chat)
        else:
            dataset = [{"conversations": default_conversations}]
    else:
        dataset = load_dataset(args.finetune_hf_dataset, split=args.finetune_data_split)

    tokenizer = get_tokenizer()._tokenizer
    model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)

    report_current_memory_info()

    if args.load is not None:
        load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
        print_rank_0("Done loading checkpoint")

    unwrapped_model = unwrap_model(model)[0]
    unwrapped_model.eval()

    # Process the first ceil(percentage * len(dataset)) examples. The comparison
    # must be `>=`: with `>` one extra example would be generated.
    example_limit = args.percentage * len(dataset)

    for idx, example in enumerate(dataset):
        if idx >= example_limit:
            break
        ref_conversations = get_conversations(example)
        new_conversations = []

        # Replay only the user turns from the reference conversation. Each reply
        # is generated by the model and appended as the assistant turn, so later
        # user turns are conditioned on the model's own earlier replies.
        for message in ref_conversations:
            if message["role"] != "user":
                continue
            new_conversations.append(message)
            print_rank_0(
                "{}".format(
                    tokenizer.apply_chat_template(
                        new_conversations, tokenize=False, add_generation_prompt=True
                    )
                )
            )
            input_ids = tokenizer.apply_chat_template(
                new_conversations, return_tensors="pt", add_generation_prompt=True
            )
            output_ids = simple_generate(
                unwrapped_model, input_ids.cuda(), osl=args.osl, disable_tqdm=args.disable_tqdm
            )
            # NOTE(review): decoded without skip_special_tokens. Whether output_ids
            # holds prompt+completion or only the completion depends on
            # simple_generate — confirm before relying on the appended text.
            output_texts = tokenizer.batch_decode(output_ids)[0]
            print_rank_0("{}".format(output_texts))
            new_conversations.append({"role": "assistant", "content": output_texts})

    torch.distributed.barrier()