# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import asyncio
from argparse import Namespace

from pydantic import PrivateAttr

from megatron.core import parallel_state
from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
from megatron.core.inference.coordinator import DynamicEngineCoordinator
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
    SimpleTextGenerationController,
)
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.module import MegatronModule
from megatron.training.global_vars import get_args, get_tokenizer

from ..inference.inference_interface import (
    ChatInferenceInterface,
    InferenceRequest,
    InferenceResponse,
    ReturnsRaw,
    ReturnsTokens,
)
from ..server.api import InferenceServer


## This code is copied from tools/run_text_generation_server.py
def get_static_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
    """Build a static-batching MCore inference engine around ``model``.

    Args:
        args (Namespace): The user arguments parsed from the command line.
        model (MegatronModule): The megatron model to serve.

    Returns:
        AbstractEngine: An ``MCoreEngine`` wrapping the model.
    """
    # When --inference-max-batch-size is not provided, fall back to a batch of one.
    max_batch_size = (
        args.inference_max_batch_size if args.inference_max_batch_size is not None else 1
    )

    wrapper_config = InferenceWrapperConfig(
        hidden_size=args.hidden_size,
        inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
        fp32_residual_connection=args.fp32_residual_connection,
        params_dtype=args.params_dtype,
        padded_vocab_size=args.padded_vocab_size,
        inference_max_seq_length=args.inference_max_seq_length,
        inference_max_requests=max_batch_size,
        nccl_all_reduce_for_prefill=args.nccl_all_reduce_for_prefill,
    )

    wrapped_model = GPTInferenceWrapper(model, wrapper_config)
    controller = SimpleTextGenerationController(
        inference_wrapped_model=wrapped_model, tokenizer=get_tokenizer()
    )

    return MCoreEngine(
        text_generation_controller=controller,
        max_batch_size=max_batch_size,
    )


## This code is copied from tools/run_text_generation_server.py
def get_dynamic_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
    """Build a dynamic-batching inference engine around ``model``.

    Args:
        args (Namespace): The user arguments parsed from the command line.
        model (MegatronModule): The megatron model to serve.

    Returns:
        AbstractEngine: A ``DynamicInferenceEngine`` wrapping the model.
    """
    use_cuda_graphs = args.cuda_graph_impl == "local"
    # CUDA graph count is only meaningful when the local graph impl is active.
    cuda_graph_count = (
        args.inference_dynamic_batching_num_cuda_graphs if use_cuda_graphs else None
    )

    # With grouped-query attention, the query-group count replaces the full
    # attention-head count passed to the context.
    head_count = (
        args.num_query_groups if args.group_query_attention else args.num_attention_heads
    )

    # Inference context.
    context = DynamicInferenceContext(
        params_dtype=args.params_dtype,
        num_layers=args.num_layers,
        kv_channels=args.kv_channels,
        num_attention_heads=head_count,
        max_sequence_length=args.inference_max_seq_length,
        num_cuda_graphs=cuda_graph_count,
        buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
        buffer_guaranteed_fraction=args.inference_dynamic_batching_buffer_guaranteed_fraction,
        block_size_tokens=args.inference_dynamic_batching_block_size,
        buffer_overflow_factor=args.inference_dynamic_batching_buffer_overflow_factor,
        max_requests_override=args.inference_dynamic_batching_max_requests_override,
        max_tokens_override=args.inference_dynamic_batching_max_tokens_override,
        tensor_model_parallel_size=args.tensor_model_parallel_size,
        materialize_only_last_token_logits=True,
    )

    wrapped_model = GPTInferenceWrapper(model, args, context)

    # Pipeline parallelism is in play unless this rank is simultaneously the
    # first and last pipeline stage.
    wrapped_model.model_is_pipeline_parallel = not (
        parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
    )

    controller = SimpleTextGenerationController(
        inference_wrapped_model=wrapped_model, tokenizer=get_tokenizer()
    )

    return DynamicInferenceEngine(
        controller=controller,
        context=context,
        enable_cuda_graph=use_cuda_graphs,
        random_seed=args.seed,
    )


class MegatronLocal(InferenceServer, ReturnsTokens, ReturnsRaw):
    """Interface to use MCoreEngine directly as an inference engine."""

    # Coordinator that schedules requests onto the dynamic engine.
    _coordinator: DynamicEngineCoordinator = PrivateAttr(None)
    # Handle to the engine loop task; not set anywhere in this file.
    _engine_task: asyncio.Task = PrivateAttr(None)
    # Shutdown flag; not read anywhere in this file.
    _kill_engine: bool = PrivateAttr(False)

    async def base_generate(self, request: InferenceRequest):
        """Schedule every prompt in ``request`` and collect one response per prompt."""
        tokenizer = get_tokenizer()
        gen_args = request.generation_args

        # Fall back to defaults for any sampling knob the caller left unset/falsy.
        params = SamplingParams(
            num_tokens_to_generate=gen_args.max_tokens or 1024,
            temperature=gen_args.temperature or 1.0,
            top_k=gen_args.top_k or 0,
            top_p=gen_args.top_p or 0.0,
            termination_id=self._coordinator.engine.controller.tokenizer.eod,
            return_log_probs=True,
            skip_prompt_log_probs_for_dynamic_inference=True,
            add_BOS=tokenizer.bos is not None,
        )

        # Schedule every prompt up front, then await all responses together.
        pending = []
        for prompt in request.prompt:
            req_id = self._coordinator.schedule_request(prompt=prompt, sampling_params=params)
            pending.append(self._coordinator.get_response(req_id))
        results = await asyncio.gather(*pending)

        replies = []
        for prompt, result in zip(request.prompt, results):
            replies.append(
                InferenceResponse(
                    response=result.generated_text,
                    raw_text=prompt + result.generated_text,
                    token_ids=result.prompt_tokens.tolist() + result.generated_tokens,
                    logprobs=result.generated_log_probs,
                    prompt_length=len(result.prompt_tokens),
                )
            )
        return replies

    @classmethod
    async def launch(cls, model: GPTModel, **kwargs):
        """Create a server instance backed by a dynamic engine coordinator."""
        args = get_args()

        engine: DynamicInferenceEngine = get_dynamic_inference_engine(args, model)
        coordinator = DynamicEngineCoordinator(
            engine,
            inference_max_requests=engine.context.max_requests,
            log_level=0,
        )

        server = cls(**kwargs)
        server._coordinator = coordinator

        # Start the coordinator on the event loop we are currently running on.
        coordinator.startup(asyncio.get_running_loop())

        return server

    async def kill(self):
        """Shut the coordinator (and its engine) down."""
        await self._coordinator.shutdown()

    async def suspend(self):
        """Pause the underlying engine without shutting it down."""
        await self._coordinator.suspend_engine()


class MegatronChatLocal(ChatInferenceInterface, MegatronLocal):
    """Chat variant of ``MegatronLocal``.

    Adds no members of its own: behavior comes entirely from combining
    ``ChatInferenceInterface`` with ``MegatronLocal``'s engine plumbing.
    """