# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Forward step utilities."""

from collections.abc import Iterable

import torch

from megatron.training import get_args
from megatron.core import mpu, InferenceParams
from .communication import (
    send_to_next_pipeline_rank,
    recv_from_prev_pipeline_rank_)


class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_length):
        """Set values so we don't need to do it multiple times."""
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        # Make sure the model is in eval mode.
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_length)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold for whether we split up the batch for pipelining.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold

    def _forward(self, tokens, position_ids, attention_mask):
        return self.model(tokens,
                          position_ids,
                          attention_mask,
                          inference_params=self.inference_params)

    def __call__(self, tokens, position_ids, attention_mask,
                 recv_buffer_seq_length=None):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining case: runs only when
        # current_batch_x_seqlen >= args.inference_batch_times_seqlen_threshold
        # and args.pipeline_model_parallel_size > 1. The batch is split into
        # smaller micro-batches that are pipelined through the stages.
        if self.pipeline_size_larger_than_one:
            seq_len = tokens.size(1) if recv_buffer_seq_length is None \
                else recv_buffer_seq_length
            current_batch_x_seqlen = tokens.size(0) * seq_len
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // seq_len)
                return self._with_pipelining_forward_step(
                    tokens,
                    position_ids,
                    attention_mask,
                    micro_batch_size,
                    recv_buffer_seq_length=recv_buffer_seq_length)

        recv_buffer = None
        if recv_buffer_seq_length is not None:
            recv_buffer = _allocate_recv_buffer(tokens.size(0),
                                                recv_buffer_seq_length)

        return self._no_pipelining_forward_step(tokens,
                                                position_ids,
                                                attention_mask,
                                                recv_buffer=recv_buffer)

    def _forward_step_helper(self, tokens, position_ids, attention_mask,
                             recv_buffer=None):
        """Single forward step. If no receive buffer is provided,
        allocate one on the fly."""
        batch_size = tokens.size(0)
        sequence_length = tokens.size(1)
        if recv_buffer is None:
            recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

        # Receive from the previous stage.
        if recv_buffer is not None and torch.numel(recv_buffer) > 0:
            recv_from_prev_pipeline_rank_(recv_buffer)

        # Forward pass through the model.
        if not mpu.is_pipeline_first_stage():
            self.model.set_input_tensor(recv_buffer)
        output_tensor = self._forward(tokens, position_ids, attention_mask)

        if isinstance(output_tensor, tuple):
            output_tensor = output_tensor[0]

        # Send output to the next stage.
        send_to_next_pipeline_rank(output_tensor)

        return output_tensor

    def _no_pipelining_forward_step(self, tokens, position_ids, attention_mask,
                                    recv_buffer=None):
        """If recv_buffer is None, we will allocate one on the fly."""
        # Run a simple forward pass.
        output_tensor = self._forward_step_helper(tokens,
                                                  position_ids,
                                                  attention_mask,
                                                  recv_buffer=recv_buffer)
        # Update the sequence length offset.
        self.inference_params.sequence_len_offset += tokens.size(1)

        logits = None
        if mpu.is_pipeline_last_stage():
            logits = output_tensor

        return logits

    def _with_pipelining_forward_step(self, tokens, position_ids, attention_mask,
                                      micro_batch_size,
                                      recv_buffer_seq_length=None):
        """No interleaving is supported."""
        batch_size = tokens.size(0)
        sequence_length = tokens.size(1) if recv_buffer_seq_length is None \
            else recv_buffer_seq_length

        # Divide the batch dimension into micro-batches.
        num_micro_batches, last_chunk = divmod(batch_size,
                                               micro_batch_size)
        if last_chunk > 0:
            num_micro_batches += 1

        # Preallocate memory for output logits.
        logits = None
        if mpu.is_pipeline_last_stage():
            args = get_args()
            logits = torch.empty(
                (batch_size, sequence_length, args.padded_vocab_size),
                dtype=torch.float32,
                device=torch.cuda.current_device())

        # Preallocate the receive buffer.
        recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

        for micro_batch_index in range(num_micro_batches):
            # Slice along the batch dimension.
            start = micro_batch_index * micro_batch_size
            end = min(start + micro_batch_size, batch_size)
            this_micro_batch_size = end - start
            tokens2use = tokens[start:end, ...]
            position_ids2use = position_ids[start:end, ...]

            # Run a simple forward pass. A smaller last chunk cannot reuse the
            # preallocated buffer, so let the helper allocate a fresh one.
            if this_micro_batch_size != micro_batch_size:
                recv_buffer = None
            output = self._forward_step_helper(tokens2use,
                                               position_ids2use,
                                               attention_mask,
                                               recv_buffer=recv_buffer)

            # Adjust the batch size offset to account for the micro-batch.
            self.inference_params.batch_size_offset += this_micro_batch_size

            # Copy logits.
            if mpu.is_pipeline_last_stage():
                logits[start:end, ...] = output

        # Once we are done with all the micro-batches, adjust the
        # sequence length offset ...
        self.inference_params.sequence_len_offset += tokens.size(1)
        # ... and reset the batch size offset.
        self.inference_params.batch_size_offset = 0

        return logits


def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype


def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())
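

# Usage sketch (illustrative only, not executed as part of this module): a
# minimal view of how ForwardStep is typically driven for one decoding step.
# It assumes Megatron has already been initialized (so get_args() and mpu are
# set up) and that `model`, `tokens`, `position_ids`, and `attention_mask` are
# placeholders supplied by the caller, not names defined in this file.
#
#   forward_step = ForwardStep(model, max_batch_size=8, max_sequence_length=2048)
#   with torch.no_grad():
#       logits = forward_step(tokens, position_ids, attention_mask)
#   # `logits` is only materialized on the last pipeline stage; every other
#   # stage sends its activations downstream and returns None.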