from abc import ABCMeta

import torch
from qtorch.quant import float_quantize

from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER

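# lightx2v_kernel provides the fused MX/NV quantization CUDA kernels. They are
# optional: if the import fails, only the int8/fp8 paths below are usable.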
try:
    from lightx2v_kernel.gemm import scaled_mxfp4_quant, scaled_mxfp6_quant, scaled_mxfp8_quant, scaled_nvfp4_quant
except ImportError:
    pass


class QuantTemplate(metaclass=ABCMeta):
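    """Base class for weight quantizers.

    Subclasses set ``self.weight_quant_func`` to a method that maps a 2D
    weight tensor to ``(quantized_weight, scales, extra)``.
    """
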
    def __init__(self, weight):
        if weight.dim() != 2:
            raise ValueError(f"Only 2D tensors supported. Got {weight.dim()}D tensor")
        if torch.isnan(weight).any():
            raise ValueError("Tensor contains NaN values")

        self.weight_quant_func = None
        self.extra = {}


@CONVERT_WEIGHT_REGISTER("int8")
class QuantWeightINT8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_int8_weight

    @torch.no_grad()
    def load_int8_weight(self, w, comfyui_mode=False):
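        """Symmetrically quantize ``w`` to int8 with absmax scaling.

        Returns ``(w_q, scales, extra)``. Scales are per output channel unless
        ``comfyui_mode`` is set, in which case a single per-tensor scale is used.
        """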
        org_w_shape = w.shape
        if not comfyui_mode:
            # Per-output-channel absmax -> one scale per row.
            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        else:
            # Single per-tensor absmax; the clamp guards against all-zero
            # weights, matching the per-channel path.
            max_val = w.abs().max().clamp(min=1e-5)
        # Symmetric quantization: map the absmax to the top of the int8 range.
        qmin, qmax = -128, 127
        scales = max_val / qmax
        w_q = torch.clamp(torch.round(w / scales), qmin, qmax)

        # Check for NaNs before the integer cast, where they would be lost.
        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w_q).sum() == 0

        w_q = w_q.to(torch.int8)

        if not comfyui_mode:
            scales = scales.view(org_w_shape[0], -1)
            w_q = w_q.reshape(org_w_shape)

        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("fp8")
class QuantWeightFP8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_fp8_weight

    @torch.no_grad()
    def load_fp8_weight(self, w, comfyui_mode=False):
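        """Quantize ``w`` to float8 (e4m3fn) with absmax scaling.

        Mirrors the int8 path, but rounds onto the e4m3 grid instead of an
        integer grid.
        """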
        org_w_shape = w.shape
        if not comfyui_mode:
            # Per-output-channel absmax -> one scale per row.
            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        else:
            # Single per-tensor absmax; the clamp guards against all-zero
            # weights, matching the per-channel path.
            max_val = w.abs().max().clamp(min=1e-5)
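        # Scale into the e4m3 range, clip, then round to the nearest
        # representable value (4 exponent bits, 3 mantissa bits).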
        finfo = torch.finfo(torch.float8_e4m3fn)
        qmin, qmax = finfo.min, finfo.max
        scales = max_val / qmax
        scaled_tensor = w / scales
        scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
        w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(torch.float8_e4m3fn)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w_q).sum() == 0

        if not comfyui_mode:
            scales = scales.view(org_w_shape[0], -1)
            w_q = w_q.reshape(org_w_shape)

        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("mxfp4")
class QuantWeightMxFP4(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp4_weight

    @torch.no_grad()
    def load_mxfp4_weight(self, w, comfyui_mode=False):
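        # comfyui_mode is accepted for signature parity with the int8/fp8
        # loaders but is unused by the MX/NV paths.
        # The fused kernel expects a CUDA bf16 tensor; quantize on the GPU and
        # move the results back to the original device.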
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp4_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("mxfp6")
class QuantWeightMxFP6(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp6_weight

    @torch.no_grad()
    def load_mxfp6_weight(self, w, comfyui_mode=False):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp6_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("mxfp8")
class QuantWeightMxFP8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp8_weight

    @torch.no_grad()
    def load_mxfp8_weight(self, w, comfyui_mode=False):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp8_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("nvfp4")
class QuantWeightNVFP4(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_fp4_weight

    @torch.no_grad()
    def load_fp4_weight(self, w, comfyui_mode=False):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
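        # Two-level NVFP4 scaling: 2688 = 448 (e4m3 max, the range of the
        # per-block scales) * 6 (fp4 e2m1 max). The global scale maps the
        # tensor absmax onto that combined range.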
        weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32)
        w_q, scales = scaled_nvfp4_quant(w, weight_global_scale)
        w_q, scales = w_q.to(device), scales.to(device)
        self.extra["weight_global_scale"] = weight_global_scale.to(device)
        return w_q, scales, self.extra
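

# Minimal usage sketch (illustrative only; not part of the module):
#
#     weight = torch.randn(4096, 4096)
#     quantizer = QuantWeightINT8(weight)
#     w_q, scales, extra = quantizer.weight_quant_func(weight)
#     # w_q: int8, same shape as weight; scales: (4096, 1)
#     w_approx = w_q.float() * scales  # rough dequantized reconstruction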