import re
from abc import ABCMeta, abstractmethod

import torch

from lightx2v.utils.envs import *
from lightx2v.utils.ggml_tensor import GGMLTensor
from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tensor
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

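# Optional quantized-GEMM backends. Each import below is best-effort: if a package is
# missing, its symbols are left as None and only the registered MM classes that rely on
# that backend become unusable; the rest of the module still imports cleanly.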
try:
    from lightx2v_kernel.gemm import (
        cutlass_scaled_mxfp4_mm,
        cutlass_scaled_mxfp6_mxfp8_mm,
        cutlass_scaled_mxfp8_mm,
        cutlass_scaled_nvfp4_mm,
        scaled_mxfp4_quant,
        scaled_mxfp6_quant,
        scaled_mxfp8_quant,
        scaled_nvfp4_quant,
    )
except ImportError:
    scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None
    scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None
    scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None
    scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None

try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None

try:
    from q8_kernels.functional.linear import q8_linear
except ImportError:
    q8_linear = None

try:
    from q8_kernels.functional.linear import fp8_linear
except ImportError:
    fp8_linear = None

try:
    import deep_gemm
except ImportError:
    deep_gemm = None

try:
    from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ImportError:
    quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

try:
    import gguf
except ImportError:
    gguf = None

try:
    import marlin_cuda_quant
except ImportError:
    marlin_cuda_quant = None


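# Base container for one linear layer's parameters. Subclasses implement load()/apply();
# the template itself only tracks tensor names, offload flags, and the pinned-host /
# device round-trip used by to_cuda() and to_cpu().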
class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config={}):
        self.config = config

    def to_cuda(self, non_blocking=False):
        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)

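    # Offload path: when pinned host buffers exist, device tensors are copied back into
    # them (so a later to_cuda() can use non_blocking transfers); otherwise this falls
    # back to a plain .to("cpu").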
    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)


@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)

    def load(self, weight_dict):
        if self.create_cuda_buffer:
            self._load_cuda_buffers(weight_dict)
        elif self.create_cpu_buffer:
            self._load_cpu_pin_buffers()
        else:
            self._load_default_tensors(weight_dict)

    def _get_source_tensor(self, source_name, weight_dict=None):
        if self.lazy_load:
            return self.lazy_load_file.get_tensor(source_name)
        return weight_dict[source_name]

    def _create_pin_tensor(self, tensor, transpose=False):
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pin_tensor = pin_tensor.copy_(tensor)
        if transpose:
            pin_tensor = pin_tensor.t()
        del tensor
        return pin_tensor

    def _load_cuda_buffers(self, weight_dict):
        self.weight_cuda_buffer = self._get_source_tensor(self.weight_name, weight_dict).t().to(AI_DEVICE)
        if self.bias_name is not None:
            self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)

    def _load_cpu_pin_buffers(self):
        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
        self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)

        if self.bias_name is not None:
            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
            self.pin_bias = self._create_pin_tensor(bias_tensor)
        else:
            self.bias = None
            self.pin_bias = None

    def _load_default_tensors(self, weight_dict):
        if not self.lazy_load:
            device = weight_dict[self.weight_name].device
            if device.type == "cpu":
                weight_tensor = weight_dict[self.weight_name]
                self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)

                if self.bias_name is not None:
                    bias_tensor = weight_dict[self.bias_name]
                    self.pin_bias = self._create_pin_tensor(bias_tensor)
                else:
                    self.bias = None
                    self.pin_bias = None
                del weight_dict[self.weight_name]
            else:
                self.weight = weight_dict[self.weight_name].t()
                self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None

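    # The weight is stored pre-transposed to [in_features, out_features], so apply() is a
    # plain x @ W (fused with the bias via torch.addmm when present) written into a
    # preallocated output buffer.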
    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
        if self.bias_name is not None:
            destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
        return destination

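    # Block-wise lazy reload: the first ".<digits>" segment of each parameter name (for a
    # hypothetical name like "blocks.0.attn.qkv.weight") is rewritten to the requested
    # block index, and the tensors are read from the lazy_load_file handle straight into
    # the pinned buffers.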
    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            assert adapter_block_index is not None
            self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
        else:
            self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)

        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
        self.pin_weight = self.pin_weight.copy_(weight_tensor)
        del weight_tensor

        if self.bias_name is not None:
            if self.is_post_adapter:
                assert adapter_block_index is not None
                self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
            else:
                self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)

            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
            self.pin_bias.copy_(bias_tensor)
            del bias_tensor

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            assert adapter_block_index is not None
            weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
        else:
            weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)

        if weight_name not in destination:
            self.weight = None
            return

        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)

        if self.bias_name is not None:
            if self.is_post_adapter:
                assert adapter_block_index is not None
                bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
            else:
                bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
            self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
        else:
            self.bias = None


@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)

    def load(self, weight_dict):
        if not self.lazy_load:
            super().load(weight_dict)
            self.weight = self.weight.to(torch.float32)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to(torch.float32)


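# Shared machinery for quantized linear layers. Subclasses choose a weight loader
# (load_func), whether the stored weight must be transposed for their kernel
# (weight_need_transpose), and a dynamic activation quantizer (act_quant_func); apply()
# then runs the backend-specific quantized GEMM.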
class MMWeightQuantTemplate(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
        self.load_func = None
        self.weight_need_transpose = True
        self.act_quant_func = None
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.infer_dtype = GET_DTYPE()
        self.bias_force_fp32 = False

    # =========================
    # weight load functions
    # =========================
    def load(self, weight_dict):
        self.load_quantized(weight_dict)
        if self.weight_need_transpose:
            if hasattr(self, "weight") and self.weight is not None:
                self.weight = self.weight.t()
            if hasattr(self, "pin_weight") and self.pin_weight is not None:
                self.pin_weight = self.pin_weight.t()
            if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None:
                self.weight_cuda_buffer = self.weight_cuda_buffer.t()

    def load_quantized(self, weight_dict):
        if self.create_cuda_buffer:
            self._load_cuda_buffers(weight_dict)
        elif self.create_cpu_buffer:
            self._load_cpu_pin_buffers()
        else:
            self._load_default_tensors(weight_dict)

    def _load_cuda_buffers(self, weight_dict):
        source = self.lazy_load_file if self.lazy_load else weight_dict
        self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
        self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)

    def _get_cuda_tensor_pair(self, source, is_lazy):
        if is_lazy:
            weight = source.get_tensor(self.weight_name).to(AI_DEVICE)
            scale = source.get_tensor(self.weight_scale_name).float().to(AI_DEVICE)
        else:
            weight = source[self.weight_name].to(AI_DEVICE)
            scale = source[self.weight_scale_name].float().to(AI_DEVICE)
        return weight, scale

    def _get_cuda_bias_tensor(self, source, is_lazy):
        if self.bias_name is None:
            return None
        if is_lazy:
            bias = source.get_tensor(self.bias_name)
            dtype = self.infer_dtype
        else:
            bias = source[self.bias_name]
            dtype = bias.dtype
        if self.bias_force_fp32:
            bias = bias.to(torch.float32)
        else:
            bias = bias.to(dtype)
        return bias.to(AI_DEVICE)

    def _load_cpu_pin_buffers(self):
        self.pin_weight, self.pin_weight_scale = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True)
        self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True)
        self.bias = None

    def _get_cpu_pin_tensor_pair(self, source, is_lazy):
        if is_lazy:
            weight_tensor = source.get_tensor(self.weight_name)
            scale_tensor = source.get_tensor(self.weight_scale_name)
            scale_dtype = torch.float
        else:
            weight_tensor = source[self.weight_name]
            scale_tensor = source[self.weight_scale_name]
            scale_dtype = torch.float

        pin_weight = self._create_pin_tensor(weight_tensor)
        pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
        return pin_weight, pin_scale

    def _get_cpu_pin_bias_tensor(self, source, is_lazy):
        if self.bias_name is None:
            return None
        if is_lazy:
            bias_tensor = source.get_tensor(self.bias_name)
            if not self.bias_force_fp32:
                bias_tensor = bias_tensor.to(self.infer_dtype)
        else:
            bias_tensor = source[self.bias_name]
        if self.bias_force_fp32:
            bias_tensor = bias_tensor.to(torch.float32)
        return self._create_pin_tensor(bias_tensor)

    def _create_pin_tensor(self, tensor, dtype=None):
        dtype = dtype or tensor.dtype
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
        pin_tensor.copy_(tensor)
        del tensor
        return pin_tensor

    def _load_default_tensors(self, weight_dict):
        if not self.lazy_load:
            self.weight, self.weight_scale, self.pin_weight, self.pin_weight_scale = self._get_device_tensor_pair(weight_dict)
            self._load_default_bias(weight_dict)
        else:
            self.bias = None
            self.pin_bias = None

    def _get_device_tensor_pair(self, source):
        device = source[self.weight_name].device
        if device.type == "cpu":
            pin_weight, pin_scale = self._get_cpu_pin_tensor_pair(source, is_lazy=False)
            return None, None, pin_weight, pin_scale
        else:
            return source[self.weight_name], source[self.weight_scale_name].float(), None, None

    def _load_default_bias(self, source):
        if self.bias_name is None:
            self.bias = None
            self.pin_bias = None
            self.bias_cuda_buffer = None
            return

        if self.create_cuda_buffer:
            self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False)
            self.bias = None
            self.pin_bias = None
        else:
            bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name]
            device = bias_tensor.device
            if device.type == "cpu":
                self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False)
                self.bias = None
            else:
                self.bias = bias_tensor
                self.pin_bias = None

    def load_fp8_perchannel_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)

    def load_int8_perchannel_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name].to(torch.float32)
            w_quantizer = IntegerQuantizer(8, True, "per_channel")
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8)
            self.weight_scale = self.weight_scale.to(torch.float32)
        else:
            self.load_quantized(weight_dict)

    def load_mxfp4(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])

                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]

    def load_mxfp6(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])

                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]

    def load_mxfp8(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])

                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
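    # NVFP4 loading: the checkpoint ships packed fp4 weights with per-block scales plus a
    # per-tensor global scale. The activation global scale is derived from a calibrated
    # input absmax (2688 = 448 * 6, i.e. fp8-e4m3 max times fp4-e2m1 max, the usual NVFP4
    # recipe), and alpha = 1 / (input_global_scale * weight_global_scale) rescales the
    # GEMM output back to real values.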

    def load_nvfp4(self, weight_dict):
        device = weight_dict[self.weight_name].device

        input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")]
        input_global_scale = (2688.0 / input_absmax).to(torch.float32)
        weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
        alpha = 1.0 / (input_global_scale * weight_global_scale)

        if device.type == "cpu":
            weight_shape = weight_dict[self.weight_name].shape
            weight_dtype = weight_dict[self.weight_name].dtype
            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
            self.pin_weight.copy_(weight_dict[self.weight_name])

            weight_scale_shape = weight_dict[self.weight_scale_name].shape
            weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
            self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
            self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])

            input_global_scale_shape = input_global_scale.shape
            input_global_scale_dtype = input_global_scale.dtype
            self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype)
            self.pin_input_global_scale.copy_(input_global_scale)

            alpha_shape = alpha.shape
            alpha_dtype = alpha.dtype
            self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype)
            self.pin_alpha.copy_(alpha)

            del weight_dict[self.weight_name]
        else:
            self.weight = weight_dict[self.weight_name]
            self.weight_scale = weight_dict[self.weight_scale_name]
            self.input_global_scale = input_global_scale
            self.alpha = alpha

        if self.bias_name is not None:
            if self.create_cuda_buffer:
                self.bias_cuda_buffer = weight_dict[self.bias_name].to(AI_DEVICE)
            else:
                device = weight_dict[self.bias_name].device
                if device.type == "cpu":
                    bias_shape = weight_dict[self.bias_name].shape
                    bias_dtype = weight_dict[self.bias_name].dtype
                    self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                    self.pin_bias.copy_(weight_dict[self.bias_name])
                else:
                    self.bias = weight_dict[self.bias_name]
        else:
            self.bias = None
            self.pin_bias = None

    def load_fp8_perblock128_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name]
            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
        else:
            self.load_quantized(weight_dict)

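    # 128x128 block quantization for DeepGEMM: pad the matrix up to multiples of 128, take
    # the absmax of each block, and scale the block so that absmax maps to 448 (the
    # float8_e4m3fn maximum). Returns the fp8 weight and one float32 scale per block.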
    def per_block_cast_to_fp8(self, x):
        assert x.dim() == 2
        m, n = x.shape
        x_padded = torch.zeros(
            (deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
            dtype=x.dtype,
            device=x.device,
        )
        x_padded[:m, :n] = x
        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))

    # =========================
    # act quant kernels
    # =========================
    def act_quant_int8_perchannel_sym_torchao(self, x):
        input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannel_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_int8_perchannel_sym_vllm(self, x):
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_nvfp4(self, x):
        input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_mxfp4(self, x):
        input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_mxfp8(self, x):
        input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x)
        return input_tensor_quant, input_tensor_scale

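    # Per-token, per-group dynamic activation quantization with group size 128 (the layout
    # DeepGEMM consumes): each 128-wide slice of a row gets its own scale, again mapping
    # the slice absmax to the fp8-e4m3 limit of 448.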
    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
        x_view = x.view(m, -1, 128)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
        return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)

    def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
        m, k = x.shape
        input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
        input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
        sgl_kernel.sgl_per_token_group_quant_fp8(
            x,
            input_tensor_quant,
            input_tensor_scale,
            group_size=128,
            eps=1e-10,
            fp8_min=-448.0,
            fp8_max=448.0,
        )
        return input_tensor_quant, input_tensor_scale

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
        if self.bias_name is not None:
            destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
        destination[self.weight_scale_name] = self.pin_weight_scale if hasattr(self, "pin_weight_scale") else self.weight_scale
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
            weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1)
        else:
            weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
            weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)

        if weight_name not in destination:
            self.weight = None
            return

        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
        self.weight_scale = self.weight_scale_cuda_buffer.copy_(destination[weight_scale_name], non_blocking=True)

        if self.bias_name is not None:
            bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
            self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
        else:
            self.bias = None

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
            self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1)
        else:
            self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
            self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)

        if self.weight_need_transpose:
            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
        else:
            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
        self.pin_weight = self.pin_weight.copy_(weight_tensor)

        weight_scale_tensor = self.lazy_load_file.get_tensor(self.weight_scale_name)
        self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)

        del weight_tensor

        if self.bias_name is not None:
            if self.is_post_adapter:
                assert adapter_block_index is not None
                self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
            else:
                self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)

            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
            self.pin_bias.copy_(bias_tensor)
            del bias_tensor


@MM_WEIGHT_REGISTER("fp8-vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
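        # vLLM's cutlass_scaled_mm custom op (out, a, b, a_scales, b_scales, bias):
        # computes scale_a * scale_b * (a @ b) + bias and writes the result into
        # output_tensor in the requested dtype.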
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        torch.ops._C.cutlass_scaled_mm(
            output_tensor,
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp4")
class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp4-A-mxfp4-dynamic

    Quant MM:
        Weight: mxfp4
        Act: mxfp4
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_mxfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp4
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp6-A-mxfp8-dynamic

    Quant MM:
        Weight: mxfp6
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_mxfp6
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp8")
class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp8-A-mxfp8-dynamic

    Quant MM:
        Weight: mxfp8
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_mxfp8
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("nvfp4")
class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-nvfp4-A-nvfp4-dynamic

    Quant MM:
        Weight: nvfp4
        Act: nvfp4
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_nvfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_nvfp4

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    def to_cuda(self, non_blocking=False):
        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
            self.input_global_scale = self.pin_input_global_scale.to(AI_DEVICE, non_blocking=non_blocking)
            self.alpha = self.pin_alpha.to(AI_DEVICE, non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
                self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
                self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
                self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
                self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)


@MM_WEIGHT_REGISTER("Calib")
class MMCalibNvfp4(MMWeight):
    """
    Name: calib

    Calib:
        absmax: torch.max(torch.abs(input_tensor))
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.running_absmax = None
        self.count = 0
        self.decay = 0.9

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype, device = input_tensor.dtype, input_tensor.device

        current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
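        # Maintain an exponential moving average of the activation absmax (decay 0.9,
        # updated every other call) and publish it to the global CALIB dict; this is the
        # statistic intended to become the ".input_absmax" entry consumed by the NVFP4
        # loader.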
        if self.count % 2 == 0:
            if self.running_absmax is None:
                self.running_absmax = current_absmax
            else:
                self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
            CALIB["absmax"][self.weight_name] = self.running_absmax
        self.count = self.count + 1

        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)


@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
        self.bias_force_fp32 = True

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = fp8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor


@MM_WEIGHT_REGISTER("int8-q8f")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Q8F
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = q8_linear(
            input_tensor_quant,
            self.weight,
            self.bias.float() if self.bias is not None else None,
            input_tensor_scale,
            self.weight_scale,
            fuse_gelu=False,
            out_dtype=self.infer_dtype,
        )
        return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor


@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl

    Quant MM:
        Weight: fp8 perblock 128x128 sym
        Act: fp8 pertoken-pergroup group=128 dynamic sym
        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[0])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
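        # deep_gemm.gemm_fp8_fp8_bf16_nt takes (activation, act_scale) and
        # (weight, weight_scale) with the weight kept in its original [out, in] layout
        # ("nt" = LHS untransposed, RHS transposed), which is why weight_need_transpose is
        # False and the output shape is (tokens, weight.shape[0]).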
        deep_gemm.gemm_fp8_fp8_bf16_nt(
            (input_tensor_quant, input_tensor_scale),
            (self.weight, self.weight_scale),
            output_tensor,
        )
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("fp8-sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    """
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl

    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: Sgl-kernel
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.fp8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-sgl")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)

        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = sgl_kernel.int8_scaled_mm(
            input_tensor_quant,
            self.weight,
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
            self.bias if self.bias is not None else None,
        )
        return output_tensor


@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao

    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: Torchao
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
        if self.bias is not None:
            output_tensor = output_tensor + self.bias

        return output_tensor


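# GGUF-quantized weights: tensors stay in their compressed GGMLTensor form (with pinned
# host copies for offload) and are only dequantized on demand by get_weight() via
# gguf_dequantize_tensor.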
class MMWeightGGUFTemplate(MMWeightTemplate):
    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)

    def load(self, weight_dict):
        if not self.lazy_load:
            assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
            self.weight = weight_dict[self.weight_name]

            weight_shape = self.weight.shape
            weight_dtype = self.weight.dtype

            if isinstance(self.weight, GGMLTensor):
                self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
                self.pin_weight.copy_from(self.weight)
            else:
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])

            if self.bias_name is not None:
                self.bias = weight_dict[self.bias_name]
                if isinstance(self.bias, GGMLTensor):
                    self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
                    self.pin_bias.copy_from(self.bias)
                else:
                    self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
                    self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                self.bias = None

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
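        # Rewrite the first ".<index>" occurrence in the parameter name to the requested
        # block (or adapter block) index, then copy that tensor into the CUDA buffer.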
        if self.is_post_adapter:
            assert adapter_block_index is not None
            weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
        else:
            weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)

        if weight_name not in destination:
            self.weight = None
            return

        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)

        if self.bias_name is not None:
            if self.is_post_adapter:
                assert adapter_block_index is not None
                bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
            else:
                bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
            self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
        else:
            self.bias = None

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
        if self.bias_name is not None:
            destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias

        return destination

    def get_weight(self, tensor, dtype):
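        # Dequantize a (possibly GGML-quantized) tensor to a plain torch.Tensor in `dtype`.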
        if tensor is None:
            return

        weight = gguf_dequantize_tensor(tensor, dtype)
        if isinstance(weight, GGMLTensor):
            weight = torch.Tensor(weight)

        return weight

    def cast_bias_weight(self, input_tensor=None, dtype=None, device=None, bias_dtype=None):
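        # Resolve the compute dtype from the activation when not given explicitly,
        # then dequantize the bias and weight to that dtype.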
        if input_tensor is not None:
            if dtype is None:
                dtype = getattr(input_tensor, "dtype", torch.float32)

        bias = None
        if self.bias is not None:
            bias = self.get_weight(self.bias, dtype)

        weight = self.get_weight(self.weight, dtype)
        return weight, bias

    def apply(self, input_tensor):
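        # Dequantize to the activation dtype and fall back to a regular dense linear.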
        weight, bias = self.cast_bias_weight(input_tensor)
        return torch.nn.functional.linear(input_tensor, weight, bias)


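# Registered GGUF variants. `qtype` records the GGML quantization type associated with the
# registry name; mixed "K-quant" file types (_S/_M) are represented by a single tensor type here.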
@MM_WEIGHT_REGISTER("gguf-BF16")
class MMWeightGGUFBF16(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.BF16


@MM_WEIGHT_REGISTER("gguf-Q8_0")
class MMWeightGGUFQ80(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q8_0


@MM_WEIGHT_REGISTER("gguf-Q6_K")
class MMWeightGGUFQ6K(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q6_K


@MM_WEIGHT_REGISTER("gguf-Q5_K_S")
class MMWeightGGUFQ5KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_K


@MM_WEIGHT_REGISTER("gguf-Q5_K_M")
class MMWeightGGUFQ5KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_K


@MM_WEIGHT_REGISTER("gguf-Q5_1")
class MMWeightGGUFQ51(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_1


@MM_WEIGHT_REGISTER("gguf-Q5_0")
class MMWeightGGUFQ50(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_0


@MM_WEIGHT_REGISTER("gguf-Q4_K_M")
class MMWeightGGUFQ4KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_K


@MM_WEIGHT_REGISTER("gguf-Q4_K_S")
class MMWeightGGUFQ4KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_K


@MM_WEIGHT_REGISTER("gguf-Q4_1")
class MMWeightGGUFQ41(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_1


@MM_WEIGHT_REGISTER("gguf-Q4_0")
class MMWeightGGUFQ40(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_0


@MM_WEIGHT_REGISTER("gguf-Q3_K_M")
class MMWeightGGUFQ3KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q3_K


@MM_WEIGHT_REGISTER("gguf-Q3_K_S")
class MMWeightGGUFQ3KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q3_K


@MM_WEIGHT_REGISTER("int4-g128-marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: "W-int4-group128-sym-Marlin

    Quant int4 x FP16:
        Weight: int4 pergroup sym
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
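        # Marlin expects pre-packed int4 weights and group scales (loaded via load_quantized)
        # plus an auxiliary workspace tensor stored under "<weight_name>_workspace".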
        assert not self.lazy_load
        self.load_func(weight_dict)
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]

        if self.bias_name is not None:
            bias_shape = weight_dict[self.bias_name].shape
            bias_dtype = weight_dict[self.bias_name].dtype
            self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
            self.bias.copy_(weight_dict[self.bias_name])
        else:
            self.bias = None

    def apply(self, input_tensor):
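        # marlin_cuda_quant.mul computes input @ dequant(weight) into output_tensor using the
        # fp16 group scales; the -1 arguments leave the kernel's tuning parameters at their defaults.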
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor
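

# Minimal usage sketch (illustrative only): the exact registry lookup depends on
# lightx2v's registry_factory, and the tensor names below are hypothetical.
#
#   mm = MM_WEIGHT_REGISTER["int4-g128-marlin"]("blocks.0.attn.q.weight", "blocks.0.attn.q.bias")
#   mm.load(weight_dict)          # weight_dict holds the packed weight, scales and workspace
#   out = mm.apply(hidden_states)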