# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficienly calculate and store the context during inference."""

    def __init__(self, max_batch_size, max_sequence_length):
        self.max_sequence_length = max_sequence_length
        self.max_batch_size = max_batch_size
        self.current_batch_size = max_batch_size  # Required for bookkeeping variable-sized batches
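        # sequence_len_offset tracks how many tokens have already been cached for the
        # current batch; batch_size_offset is the starting position along the batch
        # dimension of the pre-allocated key/value memory.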
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
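        # True while generating new output tokens one step at a time; False while
        # encoding prompt tokens (see enable_prefill_mode / enable_decode_mode below).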
        self.decode_mode = False
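        # Maps layer number -> (key cache, value cache) tensors allocated and filled
        # by the attention layers during inference.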
        self.key_value_memory_dict = {}

    def swap_key_value_dict(self, batch_idx):
        "swap between batches"
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
            assert (
                len(batch_idx) == inference_key_memory.shape[1]
            )  # make sure batch size is the same
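            # The cached tensors index the batch along dimension 1, so re-ordering a
            # batch means selecting along that dimension with batch_idx.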
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory,
                new_inference_value_memory,
            )

    def enable_prefill_mode(self):
        """
        Indicates the generation loop is in the prefill phase (still processing
        input prompt tokens). This should be enabled if the generation loop is
        encoding prompt tokens for *any* request in a batch.
        """
        self.decode_mode = False

    def enable_decode_mode(self):
        """
        Indicates the generation loop is in the decode phase (generating new output
        tokens). This should only be enabled if the generation loop has fully encoded
        the prompts for *all* requests in a batch.
        """
        self.decode_mode = True

    def reset(self):
        """Resets the inference state for a new batch."""
        self.current_batch_size = self.max_batch_size
        self.sequence_len_offset = 0
        self.batch_size_offset = 0
        self.enable_prefill_mode()

    def __str__(self):
        return (
            f"InferenceParams(max_seq_len = {self.max_sequence_length}, "
            f"max_batch_size = {self.max_batch_size}, "
            f"current_batch_size = {self.current_batch_size}, "
            f"sequence_len_offset = {self.sequence_len_offset}, "
            f"batch_size_offset = {self.batch_size_offset}, "
            f"key_value_memory_dict = {self.key_value_memory_dict.keys()})"
            f"decode_mode = {self.decode_mode}"
        )

    def __eq__(self, other):
        if not isinstance(other, InferenceParams):
            return False

        # Check all attributes match
        basic_attrs = [
            'max_sequence_length',
            'max_batch_size',
            'current_batch_size',
            'sequence_len_offset',
            'batch_size_offset',
        ]

        if not all(
            hasattr(other, attr) and getattr(self, attr) == getattr(other, attr)
            for attr in basic_attrs
        ):
            return False

        # Check dictionary keys match; i.e. the same number of layers are cached
        if self.key_value_memory_dict.keys() != other.key_value_memory_dict.keys():
            return False

        # Check each tensor tuple in the dictionary
        for key in self.key_value_memory_dict:
            self_tensors = self.key_value_memory_dict[key]
            other_tensors = other.key_value_memory_dict[key]

            # Compare each key, value tensor in the tuple
            for self_tensor, other_tensor in zip(self_tensors, other_tensors):
                if (
                    self_tensor.data_ptr() != other_tensor.data_ptr()
                    or self_tensor.shape != other_tensor.shape
                ):
                    return False
        return True
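

# Illustrative sketch (not part of the library API): how a generation loop might
# drive InferenceParams. The layer count, tensor shapes, and the PyTorch dependency
# below are assumptions for demonstration only; in practice the attention layers
# allocate and populate key_value_memory_dict themselves.
if __name__ == "__main__":
    import torch

    max_batch, max_seq = 2, 16
    params = InferenceParams(max_batch, max_seq)

    # Hypothetical KV cache: one (key, value) pair per layer, laid out as
    # [max_seq_len, batch, num_heads, head_dim].
    num_layers, num_heads, head_dim = 2, 4, 8
    for layer_number in range(1, num_layers + 1):
        key_mem = torch.zeros(max_seq, max_batch, num_heads, head_dim)
        value_mem = torch.zeros(max_seq, max_batch, num_heads, head_dim)
        params.key_value_memory_dict[layer_number] = (key_mem, value_mem)

    # Prefill: encode the full prompt, then advance the sequence offset.
    prompt_len = 5
    params.enable_prefill_mode()
    params.sequence_len_offset += prompt_len

    # Decode: generate new tokens one step at a time.
    params.enable_decode_mode()
    for _ in range(3):
        params.sequence_len_offset += 1

    # Re-order the cached batches, e.g. to move request 1 into slot 0.
    params.swap_key_value_dict([1, 0])

    print(params)
    params.reset()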