# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.


class InferenceParams:
    """Inference parameters that are passed to the main model in order to
    efficiently calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        """Create an empty inference state.

        Args:
            max_batch_size: maximum batch size; stored as-is.
            max_sequence_length: maximum sequence length; stored as-is.
        """
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        # Both offsets start at zero; nothing in this class advances them,
        # so presumably the model code updates them — confirm with callers.
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        # layer_number -> (key_memory, value_memory). Both entries must
        # support `.shape` and `[:, idx]` indexing (see swap_key_value_dict).
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        """Reorder the batch axis (axis 1) of every cached key/value buffer.

        Each cached pair is replaced by ``(key[:, batch_idx],
        value[:, batch_idx])``. The cache dict is updated in place.

        Args:
            batch_idx: indices into axis 1 of the cached buffers. Entry ``i``
                selects which old batch entry moves to new position ``i``.
                Its length must equal the axis-1 size of every key buffer.

        Raises:
            ValueError: if the cache is empty, or if ``len(batch_idx)`` does
                not match the axis-1 size of some layer's key buffer.
        """
        if not self.key_value_memory_dict:
            raise ValueError("should not swap when dict is empty")

        # Validate every layer before mutating any of them, so a size mismatch
        # cannot leave the cache partially reordered. An explicit check is
        # used instead of `assert` so the guard survives `python -O`.
        num_indices = len(batch_idx)
        for layer_number, (key_memory, _) in self.key_value_memory_dict.items():
            batch_size = key_memory.shape[1]
            if num_indices != batch_size:
                raise ValueError(
                    f"batch_idx has {num_indices} entries but layer "
                    f"{layer_number} has batch size {batch_size}"
                )

        # Snapshot the items so the loop does not iterate the dict it updates.
        for layer_number, (key_memory, value_memory) in list(
            self.key_value_memory_dict.items()
        ):
            self.key_value_memory_dict[layer_number] = (
                key_memory[:, batch_idx],
                value_memory[:, batch_idx],
            )

    def __str__(self):
        return (
            f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
            f"max_batch_size = {self.max_batch_size}, "
            f"sequence_len_offset = {self.sequence_len_offset}, "
            f"batch_size_offset = {self.batch_size_offset}, "
            f"key_value_memory_dict = {self.key_value_memory_dict.keys()})"
        )