observer.py 3.77 KB
Newer Older
pppppM's avatar
pppppM committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# Copyright (c) OpenMMLab. All rights reserved.

import torch

from lmdeploy.lite.utils.global_avail import GlobalAvailMixin


class KVCacheObserver(GlobalAvailMixin):
    """A class to observe and record the max, min, and absolute max value of
    given tensor, per ``(head, head_dim)`` position."""

    def __init__(self, num_head: int, head_dim: int) -> None:
        """Constructor for KVCacheObserver.

        Args:
            num_head: Number of heads expected in observed tensors.
            head_dim: Size of each head's feature dimension.
        """
        self.num_head = num_head
        self.head_dim = head_dim
        # Statistics start at the neutral element of each reduction, so the
        # first observed tensor fully determines them.
        self.max_val = torch.full((num_head, head_dim),
                                  -torch.inf,
                                  dtype=torch.float16)
        self.min_val = torch.full((num_head, head_dim),
                                  torch.inf,
                                  dtype=torch.float16)
        self.absmax_val = torch.full((num_head, head_dim),
                                     0,
                                     dtype=torch.float16)

    @torch.no_grad()
    def observe(self, x: torch.Tensor) -> None:
        """Function to observe the input tensor and update the max, min, and
        absolute max values.

        Args:
            x: 4-D tensor whose dim 1 has size ``num_head`` and dim 3 has
                size ``head_dim`` (presumably laid out as
                ``(batch, num_head, seq_len, head_dim)`` -- confirm against
                callers). Statistics are reduced over dims 0 and 2.
        """
        assert len(x.shape) == 4
        x = x.transpose(1, 2)
        assert x.size(2) == self.num_head
        assert x.size(3) == self.head_dim

        # After the transpose the tensor is non-contiguous, so flatten()
        # materialises a copy; do it once instead of once per statistic.
        samples = x.flatten(0, 1)
        cur_max = samples.max(0)[0].cpu()
        cur_min = samples.min(0)[0].cpu()
        cur_absmax = samples.abs().max(0)[0].cpu()

        self.max_val = torch.maximum(self.max_val, cur_max)
        self.min_val = torch.minimum(self.min_val, cur_min)
        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)


class ActivationObserver(GlobalAvailMixin):
    """A class to observe and record the max, min, mean, absolute max, and
    absolute mean value of a given tensor, per channel.

    Also keeps track of the number of batches observed.

    Note:
        ``mean_val`` and ``absmean_val`` are the unweighted average of the
        per-batch means: every ``observe`` call contributes equally,
        regardless of how many rows the observed tensor contains.
    """

    def __init__(self, dim: int) -> None:
        """Constructor for ActivationObserver.

        Args:
            dim: Size of the last (channel) dimension of observed tensors.
        """
        self.dim = dim
        self.max_val = torch.full((dim, ), -torch.inf, dtype=torch.float16)
        self.min_val = torch.full((dim, ), torch.inf, dtype=torch.float16)
        self.absmax_val = torch.full((dim, ), 0, dtype=torch.float16)
        self.absmean_val = torch.full((dim, ), 0, dtype=torch.float16)
        self.mean_val = torch.full((dim, ), 0, dtype=torch.float16)
        # Number of observe() calls so far; drives the running means.
        self.num_batches_tracked = 0

    @torch.no_grad()
    def observe(self, x: torch.Tensor) -> None:
        """Function to observe the input tensor and update the max, min, mean,
        absolute max, absolute mean values and number of batches tracked.

        Args:
            x: 3-D tensor whose last dimension has size ``dim`` (presumably
                ``(batch, seq_len, dim)`` -- confirm against callers).
                Statistics are reduced over the first two dimensions.
        """
        assert len(x.shape) == 3
        assert x.size(2) == self.dim
        # Merge the two leading dims so each row is one sample of all
        # ``dim`` channels.
        cur_val = x.flatten(0, 1)
        cur_max = cur_val.max(0)[0].cpu()
        cur_min = cur_val.min(0)[0].cpu()
        cur_mean = cur_val.mean(0).cpu()

        cur_abs = cur_val.abs()
        cur_absmax = cur_abs.max(0)[0].cpu()
        cur_absmean = cur_abs.mean(0).cpu()

        self.max_val = torch.maximum(self.max_val, cur_max)
        self.min_val = torch.minimum(self.min_val, cur_min)
        self.absmax_val = torch.maximum(self.absmax_val, cur_absmax)

        # Running mean of per-batch means in incremental form:
        #   m_new = m + (c - m) / (k + 1)
        # This is algebraically equal to (m * k + c) / (k + 1). It avoids
        # forming m * k, which overflows float16 (max 65504) once the
        # batch count grows and would turn the mean into inf for good.
        count = self.num_batches_tracked + 1
        self.mean_val = self.mean_val + (cur_mean - self.mean_val) / count
        self.absmean_val = (self.absmean_val +
                            (cur_absmean - self.absmean_val) / count)

        # Increment the count of batches tracked
        self.num_batches_tracked = count