from abc import ABCMeta

import torch
from qtorch.quant import float_quantize  # required by QuantWeightFP8.load_fp8_weight

from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER

try:
    from lightx2v_kernel.gemm import scaled_mxfp4_quant, scaled_mxfp6_quant, scaled_mxfp8_quant, scaled_nvfp4_quant
except ImportError:
    pass


class QuantTemplate(metaclass=ABCMeta):
    """Base class for weight quantizers.

    Validates the input weight (must be 2D and free of NaNs) and exposes
    ``weight_quant_func``, which subclasses bind to their quantization routine.
    ``extra`` carries any additional metadata produced during quantization
    (e.g. the NVFP4 global scale).
    """

    def __init__(self, weight):
        if weight.dim() != 2:
            raise ValueError(f"Only 2D tensors supported. Got {weight.dim()}D tensor")
        if torch.isnan(weight).any():
            raise ValueError("Tensor contains NaN values")

        self.weight_quant_func = None
        self.extra = {}


@CONVERT_WEIGHT_REGISTER("int8")
class QuantWeightINT8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_int8_weight

    @torch.no_grad()
    def load_int8_weight(self, w, comfyui_mode=False):
        org_w_shape = w.shape
        # Row-wise (per output channel) absmax by default; a single per-tensor
        # absmax when comfyui_mode is set.
        if not comfyui_mode:
            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        else:
            max_val = w.abs().max()
        # Symmetric quantization: map the observed absmax to qmax.
        qmin, qmax = -128, 127
        scales = max_val / qmax
        w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w_q).sum() == 0

        if not comfyui_mode:
            scales = scales.view(org_w_shape[0], -1)
            w_q = w_q.reshape(org_w_shape)

        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("fp8")
class QuantWeightFP8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_fp8_weight

    @torch.no_grad()
    def load_fp8_weight(self, w, comfyui_mode=False):
        org_w_shape = w.shape
        # Same absmax scheme as int8: row-wise by default, per-tensor in comfyui_mode.
        if not comfyui_mode:
            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        else:
            max_val = w.abs().max()
        finfo = torch.finfo(torch.float8_e4m3fn)
        qmin, qmax = finfo.min, finfo.max
        scales = max_val / qmax
        scaled_tensor = w / scales
        scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
        # Round onto the e4m3 grid (4 exponent / 3 mantissa bits) before casting.
        w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(torch.float8_e4m3fn)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w_q).sum() == 0

        if not comfyui_mode:
            scales = scales.view(org_w_shape[0], -1)
            w_q = w_q.reshape(org_w_shape)

        return w_q, scales, self.extra
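

# Dequantization sketch for the two symmetric schemes above (illustrative only,
# not used elsewhere in this file): since scales = absmax / qmax, the original
# weight is approximately recovered as w_q.float() * scales, with scales
# broadcasting over columns when it has shape (out_features, 1).
def _dequantize_symmetric(w_q: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    return w_q.to(torch.float32) * scales.to(torch.float32)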


@CONVERT_WEIGHT_REGISTER("mxfp4")
class QuantWeightMxFP4(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp4_weight

    @torch.no_grad()
    def load_mxfp4_weight(self, w, comfyui_mode=False):
        # comfyui_mode is accepted but unused by the kernel-backed quantizers
        # (here and in the mxfp6/mxfp8/nvfp4 variants below); quantization is
        # delegated to the lightx2v_kernel CUDA ops on a bfloat16 copy.
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp4_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("mxfp6")
class QuantWeightMxFP6(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp6_weight

    @torch.no_grad()
    def load_mxfp6_weight(self, w, comfyui_mode=False):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp6_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("mxfp8")
class QuantWeightMxFP8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp8_weight

    @torch.no_grad()
    def load_mxfp8_weight(self, w, comfyui_mode=False):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp8_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("nvfp4")
class QuantWeightNVFP4(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_fp4_weight

    @torch.no_grad()
    def load_fp4_weight(self, w, comfyui_mode=False):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        # 2688 = 448 (fp8 e4m3 max) * 6 (fp4 e2m1 max): the global scale maps the
        # tensor absmax to the largest representable block-scale * element product.
        weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32)
        w_q, scales = scaled_nvfp4_quant(w, weight_global_scale)
        w_q, scales = w_q.to(device), scales.to(device)
        self.extra["weight_global_scale"] = weight_global_scale.to(device)
        return w_q, scales, self.extra
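

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the library). It assumes
    # CONVERT_WEIGHT_REGISTER supports dict-style lookup by the names registered
    # above; the mxfp*/nvfp4 entries additionally require a CUDA device with
    # lightx2v_kernel installed.
    weight = torch.randn(128, 256, dtype=torch.float32)
    quantizer = CONVERT_WEIGHT_REGISTER["int8"](weight)
    w_q, scales, extra = quantizer.weight_quant_func(weight)
    # Expect: torch.int8, torch.Size([128, 256]), torch.Size([128, 1])
    print(w_q.dtype, w_q.shape, scales.shape)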