mm_weight_calib.py 2.13 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
import torch
PengGao's avatar
PengGao committed
2
3

from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
helloyongyang's avatar
helloyongyang committed
4
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
PengGao's avatar
PengGao committed
5
6

from .mm_weight import MMWeight
helloyongyang's avatar
helloyongyang committed
7
8


Dongz's avatar
Dongz committed
9
@MM_WEIGHT_REGISTER("Calib")
class MMWeightCalib(MMWeight):
    """MM weight wrapper used for calibration: quantizes the raw weight and
    stores the real-quantized tensor plus its scales/zeros for later export.

    Attributes set by :meth:`load`:
        weight:        the raw (unquantized) weight tensor from the checkpoint.
        realq_weight:  the quantized weight, reshaped and cast to the target
                       storage dtype.
        scales:        per-channel scale factors (float32, shape (out, 1)).
        zeros:         zero points, or ``None`` for symmetric schemes.
    """

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        """Fetch the weight from ``weight_dict``, quantize it, and store the
        packed tensor, scales, and (optional) zero points on ``self``.

        Calibration only makes sense for a quantized scheme, so an explicit
        non-default ``mm_type`` must be present in the config.
        """
        assert self.config and self.config.get("mm_type", "Default") != "Default"
        self.weight = weight_dict[self.weight_name]
        self.get_quantizer()
        shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
        self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
        self.realq_weight = self.realq_weight.view(shape_and_dtype["tensor"][0]).contiguous().to(shape_and_dtype["tensor"][1])
        self.scales = self.scales.view(shape_and_dtype["scales"][0]).contiguous().to(shape_and_dtype["scales"][1])
        # Symmetric schemes produce no zero points (see get_quant_shape_and_dtype,
        # which maps "zeros" to None for them), so only repack when present.
        if self.zeros is not None:
            self.zeros = self.zeros.view(shape_and_dtype["zeros"][0]).contiguous().to(shape_and_dtype["zeros"][1])

    def apply(self, input_tensor):
        # Calibration does not change the forward path; defer to the base class.
        return super().apply(input_tensor)

    def get_quantizer(self):
        """Build weight/activation quantizers for the configured ``mm_type``.

        Raises:
            NotImplementedError: if ``mm_type`` is not a supported scheme.
        """
        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
            self.w_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
            self.a_setting = {"bit": "e4m3", "symmetric": True, "granularity": "per_channel"}
            self.w_quantizer = FloatQuantizer(**self.w_setting)
            self.a_quantizer = FloatQuantizer(**self.a_setting)
            self.act_dynamic_quant = True
        else:
            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")

    def get_quant_shape_and_dtype(self, shape):
        """Return target (shape, dtype) pairs for the packed weight, its scales,
        and its zero points, given the raw weight ``shape``.

        Raises:
            NotImplementedError: if ``mm_type`` is not a supported scheme.
        """
        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
            return {
                # BUGFIX: the quantizers above are configured with bit "e4m3",
                # so the storage dtype must be float8_e4m3fn (the fp8 format
                # vLLM consumes). The previous float8_e5m2 has a different
                # exponent/mantissa split and silently corrupts the values.
                "tensor": (shape, torch.float8_e4m3fn),
                "scales": ((shape[0], 1), torch.float32),
                # Symmetric quantization: no zero points.
                "zeros": None,
            }
        else:
            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")