"""A utility function for preparing model for inference
The function gets called once before the auto-regressive inference loop. It puts the model in eval mode and gets some model and inference data parameters. Extend this to build position ids, attention masks, etc., so that the required slices can be extracted during the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
"""
self.model.eval()
# For a TP-only model, both is_pp_first_stage and is_pp_last_stage return True
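# One way the pipeline-parallel flag might be derived from that observation.
# The parallel_state import (normally at module scope) is an assumption, and the
# flag name follows its later use below; treat this as a sketch, not the exact
# implementation. Later sketches in this file also assume `import torch`.
from megatron.core import parallel_state  # assumed to live at module scope

self.model_is_pipeline_parallel = not (
    parallel_state.is_pipeline_first_stage() and parallel_state.is_pipeline_last_stage()
)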
This function gets called iteratively in the inference loop. It can be used to extract the relevant inputs (prompt tokens, attention mask, etc.) required for each step of inference, as sketched below.
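# A hypothetical per-step extractor matching that description. The method name
# and the attributes it slices (self.prompts_tokens, self.position_ids,
# self.attention_mask, built during inference prep) are assumptions for
# illustration, not confirmed API.
def get_batch_for_context_window(self, context_start_position, context_end_position):
    # Only the tokens and positions up to the current context end are needed.
    tokens2use = self.prompts_tokens[:, :context_end_position]
    positions2use = self.position_ids[:, :context_end_position]
    # Square causal-mask slice covering the current context window.
    mask2use = self.attention_mask[..., :context_end_position, :context_end_position]
    return [tokens2use, positions2use, mask2use]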
"""Utility to carry out forward pass for PP models with very small inputs
If a model is pipeline parallel, yet, the input global batch is very small, we compute a foward pass on the entire global batch, rather than splitting it up into micro batches and doing something more complex as in the forward_pass_with_pipeline_parallel_large_input_batch method
Args:
inference_input (List): A list containing the inputs for the GPT model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
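"""

# A minimal sketch of how this method might look, assuming Megatron-Core's
# parallel_state plus point-to-point helpers (recv_from_prev_pipeline_rank_,
# send_to_next_pipeline_rank), an activation-buffer allocator, and a
# self._forward(...) that runs the local model chunk; all of these helper
# names are assumptions, not confirmed API.
def forward_pass_with_pipeline_parallel_small_input_batch(self, inference_input):
    tokens, position_ids, attention_mask = inference_input
    batch_size, seq_len = tokens.shape
    recv_buffer = None
    if not parallel_state.is_pipeline_first_stage():
        # Non-first stages receive the previous stage's activations.
        recv_buffer = self._allocate_recv_buffer(batch_size, seq_len)  # assumed helper
        recv_from_prev_pipeline_rank_(recv_buffer)  # assumed helper
    self.model.set_input_tensor(recv_buffer)
    output_tensor = self._forward(inference_input)  # assumed helper
    if not parallel_state.is_pipeline_last_stage():
        # Intermediate stages forward their activations downstream.
        send_to_next_pipeline_rank(output_tensor)  # assumed helper
    # Only the last stage holds real logits; other stages return None.
    return output_tensor if parallel_state.is_pipeline_last_stage() else None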
"""Runs the forward pass for models which are pipeline parallel
This is more complex than forward_pass_with_pipeline_parallel_small_input_batch because it splits the global batch into small micro batches and runs them through the model.
Args:
inference_input (List): A list containing the inputs for the GPT model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]
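"""

# A simplified sketch of the micro-batched path: carve the global batch into
# micro batches and push each through the pipeline. The micro-batch sizing via
# an inference_batch_times_seqlen_threshold attribute, the padded_vocab_size
# attribute, and the reuse of the small-batch method per micro batch are all
# illustrative assumptions.
def forward_pass_with_pipeline_parallel_large_input_batch(self, inference_input):
    tokens, position_ids, attention_mask = inference_input
    batch_size, seq_len = tokens.shape
    micro_batch_size = max(
        1, self.inference_batch_times_seqlen_threshold // seq_len  # assumed attribute
    )
    logits = None
    if parallel_state.is_pipeline_last_stage():
        # Preallocate the full output; only the last stage materializes logits.
        logits = torch.empty(
            (batch_size, seq_len, self.padded_vocab_size),  # assumed attribute
            dtype=torch.float32,
            device=tokens.device,
        )
    for start in range(0, batch_size, micro_batch_size):
        end = min(start + micro_batch_size, batch_size)
        # The broadcastable attention mask is shared across micro batches.
        micro_input = [tokens[start:end], position_ids[start:end], attention_mask]
        output = self.forward_pass_with_pipeline_parallel_small_input_batch(micro_input)
        if parallel_state.is_pipeline_last_stage():
            logits[start:end] = output
    return logits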
"""The appropriate utility is called for the forward pass, depending on the type of model parallelism used
Args:
inference_input (List): A list containing the inputs for the GPT model [tokens, position ids, attention mask]
Returns:
torch.Tensor: The output logits of shape [batch_size, seq_len, padded_vocab_size]. The logits are returned only in the last pipeline stage for PP models.
"""
if self.model_is_pipeline_parallel:
    tokens = inference_input[0]
    current_batch_size, seq_len = tokens.shape
    # If input batch is large, we need to split into micro batches and run the forward pass
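    # Hedged completion of the dispatch: the threshold attribute below is an
    # assumption for illustration; the intent is to route large batches to the
    # micro-batched path and small ones through the single-pass path.
    if current_batch_size * seq_len > self.inference_batch_times_seqlen_threshold:  # assumed
        logits = self.forward_pass_with_pipeline_parallel_large_input_batch(inference_input)
    else:
        logits = self.forward_pass_with_pipeline_parallel_small_input_batch(inference_input)
else:
    # TP-only (or single-device) models can run the whole batch in one pass.
    logits = self._forward(inference_input)  # assumed helper for the non-PP path
return logits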
"""A utility function for preparing model for inference
This function is called before the forward pass. It puts the model in eval mode, builds position ids, and creates attention masks so that required slices can be extracted during the forward pass.
Args:
prompts_tokens (torch.Tensor): A tensor of shape [batch_size, max_seq_len]
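"""

# A minimal sketch of the builders this docstring describes, using standard
# torch ops; the helper name is illustrative.
def _build_attention_mask_and_position_ids(self, prompts_tokens):
    seq_length = prompts_tokens.size(1)
    # Causal mask: True marks positions that must be masked out (upper triangle).
    attention_mask = (
        torch.tril(torch.ones((1, 1, seq_length, seq_length), device=prompts_tokens.device))
        < 0.5
    )
    # Position ids 0..seq_length-1, broadcast across the batch dimension.
    position_ids = (
        torch.arange(seq_length, dtype=torch.long, device=prompts_tokens.device)
        .unsqueeze(0)
        .expand_as(prompts_tokens)
    )
    return attention_mask, position_ids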