# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import asyncio
from argparse import Namespace

from pydantic import PrivateAttr

from megatron.core import parallel_state
from megatron.core.inference.contexts.dynamic_context import DynamicInferenceContext
from megatron.core.inference.coordinator import DynamicEngineCoordinator
from megatron.core.inference.engines.abstract_engine import AbstractEngine
from megatron.core.inference.engines.dynamic_engine import DynamicInferenceEngine
from megatron.core.inference.engines.mcore_engine import MCoreEngine
from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import (
    GPTInferenceWrapper,
)
from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import (
    InferenceWrapperConfig,
)
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
    SimpleTextGenerationController,
)
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.transformer.module import MegatronModule
from megatron.training.global_vars import get_args, get_tokenizer

from ..inference.inference_interface import (
    ChatInferenceInterface,
    InferenceRequest,
    InferenceResponse,
    ReturnsRaw,
    ReturnsTokens,
)
from ..server.api import InferenceServer


## This code is copied from tools/run_text_generation_server.py
def get_static_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
    """Get the relevant backend for running inference.

    This function will automatically choose the TRTLLMBackend when possible, and default to the
    MCore backend if the user does not specify any backend. The TRTLLMBackend is not implemented
    yet.

    Args:
        args (Namespace): The user arguments parsed from the command line.
        model (MegatronModule): The Megatron model.

    Returns:
        AbstractEngine: The chosen engine.
    """
    tokenizer = get_tokenizer()

    inference_wrapper_config = InferenceWrapperConfig(
        hidden_size=args.hidden_size,
        inference_batch_times_seqlen_threshold=args.inference_batch_times_seqlen_threshold,
        fp32_residual_connection=args.fp32_residual_connection,
        params_dtype=args.params_dtype,
        padded_vocab_size=args.padded_vocab_size,
        inference_max_seq_length=args.inference_max_seq_length,
        inference_max_requests=(
            args.inference_max_batch_size if args.inference_max_batch_size is not None else 1
        ),
        nccl_all_reduce_for_prefill=args.nccl_all_reduce_for_prefill,
    )

    inference_wrapped_model = GPTInferenceWrapper(model, inference_wrapper_config)
    text_generation_controller = SimpleTextGenerationController(
        inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
    )
    return MCoreEngine(
        text_generation_controller=text_generation_controller,
        max_batch_size=(
            args.inference_max_batch_size if args.inference_max_batch_size is not None else 1
        ),
    )


## This code is copied from tools/run_text_generation_server.py
def get_dynamic_inference_engine(args: Namespace, model: MegatronModule) -> AbstractEngine:
    """Get the relevant backend for running inference.

    This function will automatically choose the TRTLLMBackend when possible, and default to the
    MCore backend if the user does not specify any backend. The TRTLLMBackend is not implemented
    yet.

    Args:
        args (Namespace): The user arguments parsed from the command line.
        model (MegatronModule): The Megatron model.

    Returns:
        AbstractEngine: The chosen engine.
    """
    tokenizer = get_tokenizer()

    enable_cuda_graph = args.cuda_graph_impl == "local"
    num_cuda_graphs = None
    if enable_cuda_graph:
        num_cuda_graphs = args.inference_dynamic_batching_num_cuda_graphs

    # Inference context.
    inference_context = DynamicInferenceContext(
        params_dtype=args.params_dtype,
        num_layers=args.num_layers,
        kv_channels=args.kv_channels,
        # With grouped-query attention, the KV cache is sized by the number of query groups.
        num_attention_heads=(
            args.num_query_groups if args.group_query_attention else args.num_attention_heads
        ),
        max_sequence_length=args.inference_max_seq_length,
        num_cuda_graphs=num_cuda_graphs,
        buffer_size_gb=args.inference_dynamic_batching_buffer_size_gb,
        buffer_guaranteed_fraction=args.inference_dynamic_batching_buffer_guaranteed_fraction,
        block_size_tokens=args.inference_dynamic_batching_block_size,
        buffer_overflow_factor=args.inference_dynamic_batching_buffer_overflow_factor,
        max_requests_override=args.inference_dynamic_batching_max_requests_override,
        max_tokens_override=args.inference_dynamic_batching_max_tokens_override,
        tensor_model_parallel_size=args.tensor_model_parallel_size,
        materialize_only_last_token_logits=True,
    )

    inference_wrapped_model = GPTInferenceWrapper(model, args, inference_context)
    inference_wrapped_model.model_is_pipeline_parallel = not (
        parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
    )
    text_generation_controller = SimpleTextGenerationController(
        inference_wrapped_model=inference_wrapped_model, tokenizer=tokenizer
    )

    return DynamicInferenceEngine(
        controller=text_generation_controller,
        context=inference_context,
        enable_cuda_graph=enable_cuda_graph,
        random_seed=args.seed,
    )


class MegatronLocal(InferenceServer, ReturnsTokens, ReturnsRaw):
    """Inference server that runs the Megatron core dynamic inference engine in-process."""

    _coordinator: DynamicEngineCoordinator = PrivateAttr(None)
    _engine_task: asyncio.Task = PrivateAttr(None)
    _kill_engine: bool = PrivateAttr(False)

    async def base_generate(self, request: InferenceRequest):
        tokenizer = get_tokenizer()
        sampling_params = SamplingParams(
            num_tokens_to_generate=request.generation_args.max_tokens or 1024,
            temperature=request.generation_args.temperature or 1.0,
            top_k=request.generation_args.top_k or 0,
            top_p=request.generation_args.top_p or 0.0,
            termination_id=self._coordinator.engine.controller.tokenizer.eod,
            return_log_probs=True,
            skip_prompt_log_probs_for_dynamic_inference=True,
            # Prepend a BOS token only if the tokenizer defines one.
            add_BOS=tokenizer.bos is not None,
        )

        # Schedule one request per prompt, then await all responses concurrently.
        request_ids = [
            self._coordinator.schedule_request(prompt=prompt, sampling_params=sampling_params)
            for prompt in request.prompt
        ]
        responses = await asyncio.gather(
            *[self._coordinator.get_response(request_id) for request_id in request_ids]
        )

        return [
            InferenceResponse(
                response=r.generated_text,
                raw_text=p + r.generated_text,
                token_ids=r.prompt_tokens.tolist() + r.generated_tokens,
                logprobs=r.generated_log_probs,
                prompt_length=len(r.prompt_tokens),
            )
            for p, r in zip(request.prompt, responses)
        ]

    @classmethod
    async def launch(cls, model: GPTModel, **kwargs):
        args = get_args()
        inference_engine: DynamicInferenceEngine = get_dynamic_inference_engine(args, model)
        # The coordinator owns request scheduling and response delivery for the dynamic engine.
        coordinator = DynamicEngineCoordinator(
            inference_engine,
            inference_max_requests=inference_engine.context.max_requests,
            log_level=0,
        )

        launched_server = cls(**kwargs)
        launched_server._coordinator = coordinator

        # Start the coordinator on the currently running event loop.
        loop = asyncio.get_running_loop()
        coordinator.startup(loop)

        return launched_server

    async def kill(self):
        await self._coordinator.shutdown()

    async def suspend(self):
        await self._coordinator.suspend_engine()


class MegatronChatLocal(ChatInferenceInterface, MegatronLocal): ...
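

# Example usage (illustrative sketch only, not exercised by this module): `MegatronLocal.launch`
# must be awaited inside a running event loop, and the surrounding script is assumed to have
# already built the GPT model and initialized Megatron's global args and tokenizer. The exact
# constructor arguments of `InferenceRequest` (and the `GenerationArgs` name) are assumptions
# taken from how `base_generate` reads the request, not a confirmed API.
#
#     async def example(model: GPTModel):
#         server = await MegatronLocal.launch(model)
#         request = InferenceRequest(
#             prompt=["Hello, world!"],
#             generation_args=GenerationArgs(max_tokens=64, temperature=1.0),
#         )
#         responses = await server.base_generate(request)
#         print(responses[0].response)
#         await server.kill()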