run_text_generation_server.py 5.43 KB
Newer Older
xingjinliang's avatar
xingjinliang committed
1
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Ryan Prenger's avatar
Ryan Prenger committed
2
3
4
5
6
7

"""Sample Generate GPT"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
                                             os.path.pardir)))
xingjinliang's avatar
xingjinliang committed
8
9
from megatron.training import get_args
from megatron.training import print_rank_0
10
from megatron.core import mpu
xingjinliang's avatar
xingjinliang committed
11
12
13
from megatron.training.checkpointing import load_checkpoint
from megatron.training.initialize import initialize_megatron
from megatron.core.models.gpt import GPTModel
Ryan Prenger's avatar
Ryan Prenger committed
14
from megatron.training import get_model
xingjinliang's avatar
xingjinliang committed
15
16
17
18
19
20
21
22
23
24
25
26
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from megatron.inference.text_generation_server import MegatronServer
from megatron.inference.text_generation import generate_and_post_process
from megatron.inference.text_generation import beam_search_and_post_process
from megatron.core.transformer.spec_utils import import_module
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)

from contextlib import nullcontext
Ryan Prenger's avatar
Ryan Prenger committed
27
import torch
xingjinliang's avatar
xingjinliang committed
28
29
30
31
32
33
34
35
from typing import Union
import megatron


def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
    """Build the GPT model used by the text generation server.

    If ``args.use_legacy_models`` is True, the legacy
    ``megatron.legacy.model.GPTModel`` is returned; otherwise the core
    ``GPTModel`` is built from a transformer layer spec.

    Args:
        pre_process (bool, optional): Set to True if you need to compute
            embeddings. Defaults to True.
        post_process (bool, optional): Set to True if you want to compute
            output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.legacy.model.GPTModel]: The constructed model.
    """
    args = get_args()
    use_te = args.transformer_impl == "transformer_engine"

    print_rank_0('building GPT model ...')

    # Experimental: load the transformer config from YAML when a file is given,
    # otherwise derive it from the command-line arguments.
    if args.yaml_cfg is not None:
        config = core_transformer_config_from_yaml(args, "language_model")
    else:
        config = core_transformer_config_from_args(args)

    # NOTE(review): parallel_output=False is used for both model flavours,
    # presumably so that full (non-sharded) logits are produced for sampling
    # — confirm against the GPTModel implementations.
    if args.use_legacy_models:
        return megatron.legacy.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=False,
            pre_process=pre_process,
            post_process=post_process,
        )

    # An explicit --spec module takes precedence over the built-in layer specs.
    if args.spec is not None:
        transformer_layer_spec = import_module(args.spec)
    elif use_te:
        transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
    else:
        transformer_layer_spec = get_gpt_layer_local_spec(
            args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)

    return GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
        parallel_output=False,
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
        rotary_base=args.rotary_base,
        rope_scaling=args.use_rope_scaling,
    )

def add_text_generate_args(parser):
    """Register the text generation server's command-line options.

    Used as the ``extra_args_provider`` for ``initialize_megatron``.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend; a
            ``text generation`` argument group is added to it.

    Returns:
        The same ``parser`` instance, with the new options registered.
    """
    group = parser.add_argument_group(title='text generation')
    group.add_argument("--port", type=int, default=5000,
                       help='port for text generation server to run on')
    return parser


if __name__ == "__main__":
    initialize_megatron(extra_args_provider=add_text_generate_args,
                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
                                       'no_load_rng': True,
                                       'no_load_optim': True})

    args = get_args()
    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        # Use sys.exit rather than the site-module `exit()` helper (intended for
        # the interactive interpreter, absent under `python -S`), and report a
        # non-zero status because this is an unsupported configuration.
        sys.exit(1)

    print_rank_0("WARNING: Forcing exit_on_missing_checkpoint to True for text "
                 "generation.")
    args.exit_on_missing_checkpoint = True

    # Set up model and load checkpoint. With --fp8, the model is constructed
    # inside transformer_engine's fp8_model_init context.
    load_context = nullcontext()
    if args.fp8:
        from transformer_engine.pytorch.fp8 import fp8_model_init
        load_context = fp8_model_init()
    with load_context:
        model = get_model(model_provider, wrap_with_ddp=False)

    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    # get_model returns a list of model chunks; only one chunk is expected
    # because the virtual-pipeline configuration is rejected above.
    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]
    model.eval()

    # Only the first pipeline stage's tensor-parallel rank 0 hosts the HTTP server.
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        server = MegatronServer(model)
        server.run("0.0.0.0", port=args.port)

    # Worker loop. Each iteration receives a request type broadcast from
    # rank 0 (presumably sent by the MegatronServer side — confirm in
    # text_generation_server): 0 runs sampling-based generation and 1 runs
    # beam search.
    while True:
        choice = torch.tensor(1, dtype=torch.long, device='cuda')
        torch.distributed.broadcast(choice, 0)
        request_kind = choice.item()
        if request_kind == 0:
            try:
                generate_and_post_process(model)
            except ValueError:
                # Deliberately ignored so that one rejected request does not
                # terminate this worker; keep serving subsequent requests.
                pass
        elif request_kind == 1:
            try:
                beam_search_and_post_process(model)
            except ValueError:
                # Same best-effort policy as the generation branch above.
                pass