"lm_eval/tasks/truthfulqa-multi/truthfulqa-multi_gen_common" did not exist on "a7ca04353fe1ff967f6c5b631bc31a10a6943b23"
mm_weight_calib.py 2.12 KB
Newer Older
helloyongyang's avatar
helloyongyang committed
1
2
3
4
5
6
import torch
from .mm_weight import MMWeight
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer


@MM_WEIGHT_REGISTER("Calib")
helloyongyang's avatar
helloyongyang committed
8
9
10
class MMWeightCalib(MMWeight):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)
Dongz's avatar
Dongz committed
11

helloyongyang's avatar
helloyongyang committed
12
    def load(self, weight_dict):
Dongz's avatar
Dongz committed
13
        assert self.config and self.config.get("mm_type", "Default") != "Default"
helloyongyang's avatar
helloyongyang committed
14
15
16
17
        self.weight = weight_dict[self.weight_name]
        self.get_quantizer()
        shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
        self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
Dongz's avatar
Dongz committed
18
19
        self.realq_weight = self.realq_weight.view(shape_and_dtype["tensor"][0]).contiguous().to(shape_and_dtype["tensor"][1])
        self.scales = self.scales.view(shape_and_dtype["scales"][0]).contiguous().to(shape_and_dtype["scales"][1])
helloyongyang's avatar
helloyongyang committed
20
        if self.zeros is not None:
Dongz's avatar
Dongz committed
21
            self.zeros = self.zeros.view(shape_and_dtype["zeros"][0]).contiguous().to(shape_and_dtype["zeros"][1])
helloyongyang's avatar
helloyongyang committed
22
23
24
25
26

    def apply(self, input_tensor):
        return super().apply(input_tensor)

    def get_quantizer(self):
Dongz's avatar
Dongz committed
27
28
29
        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
            self.w_setting = {"bit": "e4m3", "symmetric": True, "granularity": "channel"}
            self.a_setting = {"bit": "e4m3", "symmetric": True, "granularity": "channel"}
helloyongyang's avatar
helloyongyang committed
30
31
32
33
            self.w_quantizer = FloatQuantizer(**self.w_setting)
            self.a_quantizer = FloatQuantizer(**self.a_setting)
            self.act_dynamic_quant = True
        else:
Dongz's avatar
Dongz committed
34
            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")
helloyongyang's avatar
helloyongyang committed
35
36

    def get_quant_shape_and_dtype(self, shape):
Dongz's avatar
Dongz committed
37
        if self.config["mm_type"] == "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm":
helloyongyang's avatar
helloyongyang committed
38
            return {
Dongz's avatar
Dongz committed
39
40
41
                "tensor": (shape, torch.float8_e5m2),
                "scales": ((shape[0], 1), torch.float32),
                "zeros": None,
helloyongyang's avatar
helloyongyang committed
42
43
            }
        else:
Dongz's avatar
Dongz committed
44
            raise NotImplementedError(f"Unsupported mm_type: {self.config['mm_type']}")