mm_weight_calib.py
import torch
from .mm_weight import MMWeight
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer


@MM_WEIGHT_REGISTER('Calib')
class MMWeightCalib(MMWeight):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
    
    def load(self, weight_dict):
        # Calibration-time load: quantize the floating-point weight once and keep
        # the packed tensor, the per-channel scales, and the optional zero points.
        assert self.config and self.config.get('mm_type', 'Default') != 'Default'
        self.weight = weight_dict[self.weight_name]
        self.get_quantizer()
        shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
        self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
        self.realq_weight = self.realq_weight.view(shape_and_dtype['tensor'][0]).contiguous().to(shape_and_dtype['tensor'][1])
        self.scales = self.scales.view(shape_and_dtype['scales'][0]).contiguous().to(shape_and_dtype['scales'][1])
        if self.zeros is not None:
            self.zeros = self.zeros.view(shape_and_dtype['zeros'][0]).contiguous().to(shape_and_dtype['zeros'][1])

    def apply(self, input_tensor):
        return super().apply(input_tensor)

    def get_quantizer(self):
        # Per-channel symmetric fp8 (e4m3) for both weights and activations;
        # activation scales are computed dynamically at inference time.
        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
            self.w_setting = {
                'bit': 'e4m3',
                'symmetric': True,
                'granularity': 'channel'
            }
            self.a_setting = {
                'bit': 'e4m3',
                'symmetric': True,
                'granularity': 'channel'
            }
            self.w_quantizer = FloatQuantizer(**self.w_setting)
            self.a_quantizer = FloatQuantizer(**self.a_setting)
            self.act_dynamic_quant = True
        else:
            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
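
    # Reference semantics (assumption, for documentation only): for the
    # channel-symmetric fp8 scheme above, real_quant_tensor() is expected to
    # behave roughly like
    #   scale[i] = max(|W[i, :]|) / 448.0      # 448.0 is the e4m3 max value
    #   Q[i, :]  = to_float8(W[i, :] / scale[i])
    # and to return (Q, scale, zeros) with zeros=None for symmetric schemes.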

    def get_quant_shape_and_dtype(self, shape):
        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
            return {
                # Packed weight dtype matches the e4m3 quantizer configured above.
                'tensor': (shape, torch.float8_e4m3fn),
                # One fp32 scale per output channel (row-wise).
                'scales': ((shape[0], 1), torch.float32),
                # Symmetric scheme: no zero points.
                'zeros': None,
            }
        else:
            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
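

# ---------------------------------------------------------------------------
# Hedged usage sketch (commented out; not part of the runtime path). It shows
# how the 'Calib' weight could be exercised in isolation. The weight names,
# tensor shapes, and the way `config` is attached here are assumptions for
# illustration; in the real pipeline the surrounding loader provides them.
# ---------------------------------------------------------------------------
# weight_dict = {
#     'blocks.0.attn.qkv.weight': torch.randn(3072, 1024, dtype=torch.bfloat16),
#     'blocks.0.attn.qkv.bias': torch.zeros(3072, dtype=torch.bfloat16),
# }
# mm_weight = MMWeightCalib('blocks.0.attn.qkv.weight', 'blocks.0.attn.qkv.bias')
# mm_weight.config = {'mm_type': 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm'}  # assumed injection
# mm_weight.load(weight_dict)
# # Expected under these assumptions:
# #   realq_weight: (3072, 1024) float8_e4m3fn; scales: (3072, 1) float32; zeros: None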